1use std::collections::HashSet;
6
7use kameo::prelude::*;
8
9use super::events::{
10 CompressStrategy, NotifyChange, NotifyCompressedForReply, RequestAppend, RequestClear,
11 RequestCompress, RequestCompressed, RequestEstimateTokens, RequestExtend, RequestFindByRole,
12 RequestImportIncremental, RequestGet, RequestImmutable, RequestIncremental, RequestInsert,
13 RequestIsEmpty, RequestLen, RequestMessages, RequestPop, RequestRemove, RequestRetain,
14 RequestSend, RequestSendStream, RequestSubscribeChange, RequestSubscribeCompressed,
15 RequestExportIncremental, RequestExportAll, RequestUnsubscribeChange, RequestUnsubscribeCompressed, RequestUpdate,
16};
17use super::stream::AgentSendStream;
18use super::types::ContextBackend;
19use crate::error::AgentError;
20
/// Reply a compression subscriber sends back: the `(summary, kept)` pair that will be
/// committed as the new `compressed` / `incremental` segments.
type CompressSubscriberReply<M> = (Vec<M>, Vec<M>);
/// Recipient that is `ask`ed with a [`NotifyCompressedForReply`] after a summarization
/// and answers with a [`CompressSubscriberReply`].
type CompressSubscriberRecipient<M> = ReplyRecipient<NotifyCompressedForReply<M>, CompressSubscriberReply<M>>;
23use crate::message::ContextMessage;
24use crate::readonly::ReadOnly;
25
/// Actor that owns an agent's conversation context and talks to its backend.
///
/// Everything sent to the backend is the concatenation of three segments, in order:
/// `immutable ++ compressed ++ incremental`.
#[derive(Actor)]
pub struct AgentContext<B: ContextBackend> {
    /// Backend used to send requests, estimate tokens, build system messages and
    /// convert messages to/from JSON.
    backend: B,
    /// Fixed prefix supplied to `new`. It is wrapped in `ReadOnly`, and no handler in
    /// this file mutates it.
    immutable: ReadOnly<B::Message>,
    /// Summary messages produced by compression; replaced wholesale on each compression.
    compressed: Vec<B::Message>,
    /// Live conversation messages; the segment edited by append/insert/update/remove/etc.
    incremental: Vec<B::Message>,
    /// Recipients notified of every `NotifyChange`; pruned when a delivery fails.
    subscribers: HashSet<Recipient<NotifyChange<B::Message>>>,
    /// Optional single subscriber asked to confirm or rewrite `(summary, kept)` before a
    /// compression is committed.
    on_compressed: Option<CompressSubscriberRecipient<B::Message>>,
}
55
56impl<B: ContextBackend> AgentContext<B> {
57 pub fn new(backend: B, immutable: Vec<B::Message>) -> Self {
62 Self {
63 backend,
64 immutable: ReadOnly::from(immutable),
65 compressed: Vec::new(),
66 incremental: Vec::new(),
67 subscribers: HashSet::new(),
68 on_compressed: None,
69 }
70 }
71
72 fn default_summary_prompt() -> String {
73 "请将以下对话历史压缩为简洁摘要,保留关键信息、决策和上下文。输出一条 system 消息。"
74 .to_string()
75 }
76
77 async fn notify_change(&mut self, event: NotifyChange<B::Message>) {
78 let mut dead = Vec::new();
79 for subscriber in &self.subscribers {
80 if let Err(_e) = subscriber.tell(event.clone()).send().await {
81 dead.push(subscriber.clone());
82 }
83 }
84 self.subscribers.retain(|s| !dead.contains(s));
85 }
86
87 async fn notify_compressed_subscriber(
89 &self,
90 summary: Vec<B::Message>,
91 kept: Vec<B::Message>,
92 ) -> Result<(Vec<B::Message>, Vec<B::Message>), AgentError> {
93 if let Some(subscriber) = &self.on_compressed {
94 subscriber
95 .ask(NotifyCompressedForReply {
96 summary,
97 kept,
98 })
99 .send()
100 .await
101 .map_err(|e| AgentError::Context(e.to_string()))
102 } else {
103 Ok((summary, kept))
104 }
105 }
106
107 async fn compress_if_full(&mut self, opts: &B::Opts) -> Result<(), AgentError> {
109 let common = opts.as_ref();
110 let mut all: Vec<B::Message> = self
111 .immutable
112 .iter()
113 .chain(self.compressed.iter())
114 .chain(self.incremental.iter())
115 .cloned()
116 .collect();
117 if let Some(ref scratch) = common.scratch {
118 all.push(self.backend.system_message(scratch.clone()));
119 }
120 let tokens = self
121 .backend
122 .estimate_tokens(&all)
123 .await
124 .unwrap_or(usize::MAX);
125 if tokens < common.context_window {
126 return Ok(());
127 }
128 if !common.auto_compress {
129 return Err(AgentError::Context("上下文已满且未启用自动压缩".into()));
130 }
131 let total = self.incremental.len();
132 let keep = total / 2;
133 if total <= keep {
134 return Ok(());
135 }
136 let split = total - keep;
137 let to_summarize: Vec<B::Message> = self.incremental.drain(..split).collect();
138 if to_summarize.is_empty() {
139 return Ok(());
140 }
141 let mut summary_messages =
142 vec![self.backend.system_message(Self::default_summary_prompt())];
143 summary_messages.append(&mut self.compressed);
144 summary_messages.extend(to_summarize);
145 let response = self.backend.send(&summary_messages, opts).await?;
146 let raw_msgs = self
147 .backend
148 .extract_messages(std::slice::from_ref(&response))?;
149 let request_msgs = self.backend.to_request_messages(raw_msgs)?;
150 let summary: Vec<B::Message> = request_msgs
151 .into_iter()
152 .map(|msg| self.backend.to_system_message(msg))
153 .collect();
154 let kept: Vec<B::Message> = self.incremental.drain(..).collect();
155 let (final_summary, final_kept) =
156 self.notify_compressed_subscriber(summary, kept).await?;
157 self.compressed = final_summary;
158 self.incremental = final_kept;
159 Ok(())
160 }
161}
162
163impl<B: ContextBackend> Message<RequestAppend<B::Message>> for AgentContext<B> {
168 type Reply = ();
169
170 async fn handle(
171 &mut self,
172 msg: RequestAppend<B::Message>,
173 _ctx: &mut Context<Self, Self::Reply>,
174 ) -> Self::Reply {
175 self.incremental.push(msg.message);
176 if let Some(last) = self.incremental.last().cloned() {
177 self.notify_change(NotifyChange::Appended(last)).await;
178 }
179 }
180}
181
182impl<B: ContextBackend> Message<RequestExtend<B::Message>> for AgentContext<B> {
183 type Reply = ();
184
185 async fn handle(
186 &mut self,
187 msg: RequestExtend<B::Message>,
188 _ctx: &mut Context<Self, Self::Reply>,
189 ) -> Self::Reply {
190 for m in msg.messages {
191 self.incremental.push(m);
192 if let Some(last) = self.incremental.last().cloned() {
193 self.notify_change(NotifyChange::Appended(last)).await;
194 }
195 }
196 }
197}
198
199impl<B: ContextBackend> Message<RequestUpdate<B::Message>> for AgentContext<B> {
200 type Reply = Result<(), AgentError>;
201
202 async fn handle(
203 &mut self,
204 msg: RequestUpdate<B::Message>,
205 _ctx: &mut Context<Self, Self::Reply>,
206 ) -> Self::Reply {
207 if msg.index >= self.incremental.len() {
208 return Err(AgentError::Context("索引越界".into()));
209 }
210 let old = std::mem::replace(&mut self.incremental[msg.index], msg.message);
211 self.notify_change(NotifyChange::Updated {
212 index: msg.index,
213 old,
214 new: self.incremental[msg.index].clone(),
215 })
216 .await;
217 Ok(())
218 }
219}
220
221impl<B: ContextBackend> Message<RequestInsert<B::Message>> for AgentContext<B> {
222 type Reply = Result<(), AgentError>;
223
224 async fn handle(
225 &mut self,
226 msg: RequestInsert<B::Message>,
227 _ctx: &mut Context<Self, Self::Reply>,
228 ) -> Self::Reply {
229 if msg.index > self.incremental.len() {
230 return Err(AgentError::Context("索引越界".into()));
231 }
232 self.incremental.insert(msg.index, msg.message);
233 self.notify_change(NotifyChange::Inserted {
234 index: msg.index,
235 message: self.incremental[msg.index].clone(),
236 })
237 .await;
238 Ok(())
239 }
240}
241
242impl<B: ContextBackend> Message<RequestRemove> for AgentContext<B> {
243 type Reply = Result<(), AgentError>;
244
245 async fn handle(
246 &mut self,
247 msg: RequestRemove,
248 _ctx: &mut Context<Self, Self::Reply>,
249 ) -> Self::Reply {
250 if msg.index >= self.incremental.len() {
251 return Err(AgentError::Context("索引越界".into()));
252 }
253 let removed = self.incremental.remove(msg.index);
254 self.notify_change(NotifyChange::Removed {
255 index: msg.index,
256 message: removed,
257 })
258 .await;
259 Ok(())
260 }
261}
262
263impl<B: ContextBackend> Message<RequestPop> for AgentContext<B> {
264 type Reply = Option<B::Message>;
265
266 async fn handle(
267 &mut self,
268 _msg: RequestPop,
269 _ctx: &mut Context<Self, Self::Reply>,
270 ) -> Self::Reply {
271 let popped = self.incremental.pop();
272 if let Some(ref msg) = popped {
273 self.notify_change(NotifyChange::Popped(msg.clone())).await;
274 }
275 popped
276 }
277}
278
279impl<B: ContextBackend> Message<RequestRetain> for AgentContext<B> {
280 type Reply = ();
281
282 async fn handle(
283 &mut self,
284 msg: RequestRetain,
285 _ctx: &mut Context<Self, Self::Reply>,
286 ) -> Self::Reply {
287 let mut removed = Vec::new();
288 let role = msg.role;
289 self.incremental.retain(|m| {
290 if m.role() == role {
291 true
292 } else {
293 removed.push(m.clone());
294 false
295 }
296 });
297 self.notify_change(NotifyChange::Retained { role, removed })
298 .await;
299 }
300}
301
302impl<B: ContextBackend> Message<RequestClear> for AgentContext<B> {
303 type Reply = ();
304
305 async fn handle(
306 &mut self,
307 _msg: RequestClear,
308 _ctx: &mut Context<Self, Self::Reply>,
309 ) -> Self::Reply {
310 if !self.incremental.is_empty() {
311 let removed = std::mem::take(&mut self.incremental);
312 self.notify_change(NotifyChange::Cleared { removed }).await;
313 }
314 }
315}
316
317impl<B: ContextBackend> Message<RequestLen> for AgentContext<B> {
322 type Reply = usize;
323
324 async fn handle(
325 &mut self,
326 _msg: RequestLen,
327 _ctx: &mut Context<Self, Self::Reply>,
328 ) -> Self::Reply {
329 self.immutable.len() + self.compressed.len() + self.incremental.len()
330 }
331}
332
333impl<B: ContextBackend> Message<RequestIsEmpty> for AgentContext<B> {
334 type Reply = bool;
335
336 async fn handle(
337 &mut self,
338 _msg: RequestIsEmpty,
339 _ctx: &mut Context<Self, Self::Reply>,
340 ) -> Self::Reply {
341 self.immutable.is_empty() && self.compressed.is_empty() && self.incremental.is_empty()
342 }
343}
344
345impl<B: ContextBackend> Message<RequestGet> for AgentContext<B> {
346 type Reply = Option<B::Message>;
347
348 async fn handle(
349 &mut self,
350 msg: RequestGet,
351 _ctx: &mut Context<Self, Self::Reply>,
352 ) -> Self::Reply {
353 let idx = msg.0;
354 let imm_len = self.immutable.len();
355 let comp_len = self.compressed.len();
356 if idx < imm_len {
357 Some(self.immutable[idx].clone())
358 } else if idx < imm_len + comp_len {
359 Some(self.compressed[idx - imm_len].clone())
360 } else if idx < imm_len + comp_len + self.incremental.len() {
361 Some(self.incremental[idx - imm_len - comp_len].clone())
362 } else {
363 None
364 }
365 }
366}
367
368impl<B: ContextBackend> Message<RequestMessages> for AgentContext<B> {
369 type Reply = Vec<B::Message>;
370
371 async fn handle(
372 &mut self,
373 _msg: RequestMessages,
374 _ctx: &mut Context<Self, Self::Reply>,
375 ) -> Self::Reply {
376 self.immutable
377 .iter()
378 .chain(self.compressed.iter())
379 .chain(self.incremental.iter())
380 .cloned()
381 .collect()
382 }
383}
384
385impl<B: ContextBackend> Message<RequestImmutable> for AgentContext<B> {
386 type Reply = Vec<B::Message>;
387
388 async fn handle(
389 &mut self,
390 _msg: RequestImmutable,
391 _ctx: &mut Context<Self, Self::Reply>,
392 ) -> Self::Reply {
393 self.immutable.to_vec()
394 }
395}
396
397impl<B: ContextBackend> Message<RequestCompressed> for AgentContext<B> {
398 type Reply = Vec<B::Message>;
399
400 async fn handle(
401 &mut self,
402 _msg: RequestCompressed,
403 _ctx: &mut Context<Self, Self::Reply>,
404 ) -> Self::Reply {
405 self.compressed.clone()
406 }
407}
408
409impl<B: ContextBackend> Message<RequestIncremental> for AgentContext<B> {
410 type Reply = Vec<B::Message>;
411
412 async fn handle(
413 &mut self,
414 _msg: RequestIncremental,
415 _ctx: &mut Context<Self, Self::Reply>,
416 ) -> Self::Reply {
417 self.incremental.clone()
418 }
419}
420
421impl<B: ContextBackend> Message<RequestFindByRole> for AgentContext<B> {
422 type Reply = Vec<B::Message>;
423
424 async fn handle(
425 &mut self,
426 msg: RequestFindByRole,
427 _ctx: &mut Context<Self, Self::Reply>,
428 ) -> Self::Reply {
429 self.immutable
430 .iter()
431 .chain(self.compressed.iter())
432 .chain(self.incremental.iter())
433 .filter(|m| m.role() == msg.0)
434 .cloned()
435 .collect()
436 }
437}
438
439impl<B: ContextBackend> Message<RequestSend<B::Opts>> for AgentContext<B> {
444 type Reply = Result<B::Response, AgentError>;
445
446 async fn handle(
447 &mut self,
448 msg: RequestSend<B::Opts>,
449 _ctx: &mut Context<Self, Self::Reply>,
450 ) -> Self::Reply {
451 self.compress_if_full(&msg.opts).await?;
452 let scratch = msg.opts.as_ref().scratch.clone();
453 let mut all_messages: Vec<B::Message> = self
454 .immutable
455 .iter()
456 .chain(self.compressed.iter())
457 .chain(self.incremental.iter())
458 .cloned()
459 .collect();
460 if let Some(content) = scratch {
461 all_messages.push(self.backend.system_message(content));
462 }
463 let response = self.backend.send(&all_messages, &msg.opts).await?;
464 let raw_msgs = self
465 .backend
466 .extract_messages(std::slice::from_ref(&response))?;
467 let request_msgs = self.backend.to_request_messages(raw_msgs)?;
468 for msg in &request_msgs {
469 self.incremental.push(msg.clone());
470 self.notify_change(NotifyChange::Appended(msg.clone()))
471 .await;
472 }
473 Ok(response)
474 }
475}
476
477impl<B: ContextBackend + Clone> Message<RequestSendStream<B::Opts>> for AgentContext<B> {
478 type Reply = Result<AgentSendStream<B>, AgentError>;
479
480 async fn handle(
481 &mut self,
482 msg: RequestSendStream<B::Opts>,
483 _ctx: &mut Context<Self, Self::Reply>,
484 ) -> Self::Reply {
485 self.compress_if_full(&msg.opts).await?;
486 let scratch = msg.opts.as_ref().scratch.clone();
487 let mut all_messages: Vec<B::Message> = self
488 .immutable
489 .iter()
490 .chain(self.compressed.iter())
491 .chain(self.incremental.iter())
492 .cloned()
493 .collect();
494 if let Some(content) = scratch {
495 all_messages.push(self.backend.system_message(content));
496 }
497 let stream = self.backend.send_stream(all_messages, msg.opts);
498 Ok(AgentSendStream::new(stream))
499 }
500}
501
502impl<B: ContextBackend> Message<RequestCompress<B::Opts>> for AgentContext<B> {
507 type Reply = Result<(), AgentError>;
508
509 async fn handle(
510 &mut self,
511 msg: RequestCompress<B::Opts>,
512 _ctx: &mut Context<Self, Self::Reply>,
513 ) -> Self::Reply {
514 match msg.strategy {
515 CompressStrategy::Summarize { keep, prompt } => {
516 let total = self.incremental.len();
517 if total > keep {
518 let split = total - keep;
519 let to_summarize: Vec<B::Message> = self.incremental.drain(..split).collect();
520 if !to_summarize.is_empty() {
521 let summary_prompt = prompt.unwrap_or_else(Self::default_summary_prompt);
522 let mut summary_messages =
523 vec![self.backend.system_message(summary_prompt)];
524 summary_messages.append(&mut self.compressed);
525 summary_messages.extend(to_summarize);
526 let response = self.backend.send(&summary_messages, &msg.opts).await?;
527 let raw_msgs = self
528 .backend
529 .extract_messages(std::slice::from_ref(&response))?;
530 let request_msgs = self.backend.to_request_messages(raw_msgs)?;
531 let summary: Vec<B::Message> = request_msgs
532 .into_iter()
533 .map(|msg| self.backend.to_system_message(msg))
534 .collect();
535 let kept: Vec<B::Message> = self.incremental.drain(..).collect();
536 let (final_summary, final_kept) =
537 self.notify_compressed_subscriber(summary, kept).await?;
538 self.compressed = final_summary;
539 self.incremental = final_kept;
540 }
541 }
542 Ok(())
543 }
544 }
545 }
546}
547
548impl<B: ContextBackend> Message<RequestSubscribeChange<B::Message>> for AgentContext<B> {
553 type Reply = ();
554
555 async fn handle(
556 &mut self,
557 msg: RequestSubscribeChange<B::Message>,
558 _ctx: &mut Context<Self, Self::Reply>,
559 ) -> Self::Reply {
560 self.subscribers.insert(msg.recipient);
561 }
562}
563
564impl<B: ContextBackend> Message<RequestUnsubscribeChange<B::Message>> for AgentContext<B> {
565 type Reply = bool;
566
567 async fn handle(
568 &mut self,
569 msg: RequestUnsubscribeChange<B::Message>,
570 _ctx: &mut Context<Self, Self::Reply>,
571 ) -> Self::Reply {
572 self.subscribers.remove(&msg.recipient)
573 }
574}
575
576impl<B: ContextBackend> Message<RequestSubscribeCompressed<B::Message>> for AgentContext<B> {
577 type Reply = ();
578
579 async fn handle(
580 &mut self,
581 msg: RequestSubscribeCompressed<B::Message>,
582 _ctx: &mut Context<Self, Self::Reply>,
583 ) -> Self::Reply {
584 self.on_compressed = Some(msg.recipient);
585 }
586}
587
588impl<B: ContextBackend> Message<RequestUnsubscribeCompressed> for AgentContext<B> {
589 type Reply = ();
590
591 async fn handle(
592 &mut self,
593 _msg: RequestUnsubscribeCompressed,
594 _ctx: &mut Context<Self, Self::Reply>,
595 ) -> Self::Reply {
596 self.on_compressed = None;
597 }
598}
599
600impl<B: ContextBackend> Message<RequestEstimateTokens> for AgentContext<B> {
605 type Reply = usize;
606
607 async fn handle(
608 &mut self,
609 _msg: RequestEstimateTokens,
610 _ctx: &mut Context<Self, Self::Reply>,
611 ) -> Self::Reply {
612 let all: Vec<B::Message> = self
613 .immutable
614 .iter()
615 .chain(self.compressed.iter())
616 .chain(self.incremental.iter())
617 .cloned()
618 .collect();
619 self.backend.estimate_tokens(&all).await.unwrap_or(0)
620 }
621}
622
623impl<B: ContextBackend> Message<RequestExportAll> for AgentContext<B> {
624 type Reply = Result<String, AgentError>;
625
626 async fn handle(
627 &mut self,
628 _msg: RequestExportAll,
629 _ctx: &mut Context<Self, Self::Reply>,
630 ) -> Self::Reply {
631 let lines: Vec<String> = self
632 .immutable
633 .iter()
634 .chain(self.compressed.iter())
635 .chain(self.incremental.iter())
636 .map(|m| self.backend.message_to_json(m))
637 .collect::<Result<_, _>>()?;
638 Ok(lines.join("\n"))
639 }
640}
641
642impl<B: ContextBackend> Message<RequestExportIncremental> for AgentContext<B> {
643 type Reply = Result<String, AgentError>;
644
645 async fn handle(
646 &mut self,
647 _msg: RequestExportIncremental,
648 _ctx: &mut Context<Self, Self::Reply>,
649 ) -> Self::Reply {
650 let lines: Vec<String> = self
651 .incremental
652 .iter()
653 .map(|m| self.backend.message_to_json(m))
654 .collect::<Result<_, _>>()?;
655 Ok(lines.join("\n"))
656 }
657}
658
659impl<B: ContextBackend> Message<RequestImportIncremental> for AgentContext<B> {
660 type Reply = Result<(), AgentError>;
661
662 async fn handle(
663 &mut self,
664 msg: RequestImportIncremental,
665 _ctx: &mut Context<Self, Self::Reply>,
666 ) -> Self::Reply {
667 let mut messages = Vec::new();
668 for line in msg.json.lines() {
669 let line = line.trim();
670 if line.is_empty() {
671 continue;
672 }
673 messages.push(self.backend.message_from_json(line)?);
674 }
675 self.incremental.clear();
676 self.incremental = messages.clone();
677 self.notify_change(NotifyChange::Loaded { messages }).await;
678 Ok(())
679 }
680}