Skip to main content

agent_context/context/
actor.rs

1//! [`AgentContext`] Actor 及所有 [`Message`] 实现。
2//!
3//! 管理三区消息模型的状态,提供消息增删改查、模型对话、上下文压缩等操作。
4
5use std::collections::HashSet;
6
7use kameo::prelude::*;
8
9use super::events::{
10    CompressStrategy, NotifyChange, NotifyCompressedForReply, RequestAppend, RequestClear,
11    RequestCompress, RequestCompressed, RequestEstimateTokens, RequestExtend, RequestFindByRole,
12    RequestImportIncremental, RequestGet, RequestImmutable, RequestIncremental, RequestInsert,
13    RequestIsEmpty, RequestLen, RequestMessages, RequestPop, RequestRemove, RequestRetain,
14    RequestSend, RequestSendStream, RequestSubscribeChange, RequestSubscribeCompressed,
15    RequestExportIncremental, RequestExportAll, RequestUnsubscribeChange, RequestUnsubscribeCompressed, RequestUpdate,
16};
17use super::stream::AgentSendStream;
18use super::types::ContextBackend;
19use crate::error::AgentError;
20
/// Reply expected from the compression subscriber: a pair of message lists,
/// used by this actor as `(summary, kept)` — the new compressed zone and the
/// new incremental zone respectively.
type CompressSubscriberReply<M> = (Vec<M>, Vec<M>);
/// Recipient handle for the single compression subscriber; it is asked a
/// [`NotifyCompressedForReply`] and answers with a [`CompressSubscriberReply`].
type CompressSubscriberRecipient<M> = ReplyRecipient<NotifyCompressedForReply<M>, CompressSubscriberReply<M>>;
23use crate::message::ContextMessage;
24use crate::readonly::ReadOnly;
25
26// ---------------------------------------------------------------------------
27// AgentContext Actor
28// ---------------------------------------------------------------------------
29
/// LLM conversation context manager, implemented as a kameo actor.
///
/// Manages a three-zone message model plus an optional scratch message
/// (immutable → compressed → incremental → scratch) and provides:
/// - message create/update/delete/query ([`RequestAppend`], [`RequestUpdate`], [`RequestRemove`], ...)
/// - conversation sending ([`RequestSend`], [`RequestSendStream`]), with optional temporary
///   metadata appended through [`super::CommonOpts::scratch`]
/// - context compression ([`RequestCompress`])
/// - token estimation ([`RequestEstimateTokens`])
/// - subscription management ([`RequestSubscribeChange`], [`RequestUnsubscribeChange`],
///   [`RequestSubscribeCompressed`], [`RequestUnsubscribeCompressed`])
///
/// ## Construction
///
/// ```ignore
/// let actor_ref = AgentContext::spawn(AgentContext::new(backend, vec![]));
/// actor_ref.ask(RequestSubscribeChange { recipient: app_ref.recipient() }).await?;
/// ```
#[derive(Actor)]
pub struct AgentContext<B: ContextBackend> {
    /// Backend used to send conversations, estimate tokens and convert messages.
    backend: B,
    /// Messages supplied at construction time; never modified by any handler in this file.
    immutable: ReadOnly<B::Message>,
    /// Summary messages produced by compression; replaced wholesale on each compression.
    compressed: Vec<B::Message>,
    /// Live conversation messages; the target of every mutation request.
    incremental: Vec<B::Message>,
    /// Recipients notified of every change to the incremental zone.
    subscribers: HashSet<Recipient<NotifyChange<B::Message>>>,
    /// Optional single subscriber that may rewrite a compression result before it is stored.
    on_compressed: Option<CompressSubscriberRecipient<B::Message>>,
}
55
56impl<B: ContextBackend> AgentContext<B> {
57    /// 创建新的上下文管理器。
58    ///
59    /// - `backend`: 实现了 [`ContextBackend`] 的 LLM 后端实例
60    /// - `immutable`: 初始不可变消息(系统提示词等),放入 immutable 区
61    pub fn new(backend: B, immutable: Vec<B::Message>) -> Self {
62        Self {
63            backend,
64            immutable: ReadOnly::from(immutable),
65            compressed: Vec::new(),
66            incremental: Vec::new(),
67            subscribers: HashSet::new(),
68            on_compressed: None,
69        }
70    }
71
72    fn default_summary_prompt() -> String {
73        "请将以下对话历史压缩为简洁摘要,保留关键信息、决策和上下文。输出一条 system 消息。"
74            .to_string()
75    }
76
77    async fn notify_change(&mut self, event: NotifyChange<B::Message>) {
78        let mut dead = Vec::new();
79        for subscriber in &self.subscribers {
80            if let Err(_e) = subscriber.tell(event.clone()).send().await {
81                dead.push(subscriber.clone());
82            }
83        }
84        self.subscribers.retain(|s| !dead.contains(s));
85    }
86
87    /// 将摘要和保留消息交给压缩结果订阅者处理,若无订阅者则原样返回。
88    async fn notify_compressed_subscriber(
89        &self,
90        summary: Vec<B::Message>,
91        kept: Vec<B::Message>,
92    ) -> Result<(Vec<B::Message>, Vec<B::Message>), AgentError> {
93        if let Some(subscriber) = &self.on_compressed {
94            subscriber
95                .ask(NotifyCompressedForReply {
96                    summary,
97                    kept,
98                })
99                .send()
100                .await
101                .map_err(|e| AgentError::Context(e.to_string()))
102        } else {
103            Ok((summary, kept))
104        }
105    }
106
107    /// 检查上下文是否已满,如果已满且 `auto_compress` 则自动压缩。
108    async fn compress_if_full(&mut self, opts: &B::Opts) -> Result<(), AgentError> {
109        let common = opts.as_ref();
110        let mut all: Vec<B::Message> = self
111            .immutable
112            .iter()
113            .chain(self.compressed.iter())
114            .chain(self.incremental.iter())
115            .cloned()
116            .collect();
117        if let Some(ref scratch) = common.scratch {
118            all.push(self.backend.system_message(scratch.clone()));
119        }
120        let tokens = self
121            .backend
122            .estimate_tokens(&all)
123            .await
124            .unwrap_or(usize::MAX);
125        if tokens < common.context_window {
126            return Ok(());
127        }
128        if !common.auto_compress {
129            return Err(AgentError::Context("上下文已满且未启用自动压缩".into()));
130        }
131        let total = self.incremental.len();
132        let keep = total / 2;
133        if total <= keep {
134            return Ok(());
135        }
136        let split = total - keep;
137        let to_summarize: Vec<B::Message> = self.incremental.drain(..split).collect();
138        if to_summarize.is_empty() {
139            return Ok(());
140        }
141        let mut summary_messages =
142            vec![self.backend.system_message(Self::default_summary_prompt())];
143        summary_messages.append(&mut self.compressed);
144        summary_messages.extend(to_summarize);
145        let response = self.backend.send(&summary_messages, opts).await?;
146        let raw_msgs = self
147            .backend
148            .extract_messages(std::slice::from_ref(&response))?;
149        let request_msgs = self.backend.to_request_messages(raw_msgs)?;
150        let summary: Vec<B::Message> = request_msgs
151            .into_iter()
152            .map(|msg| self.backend.to_system_message(msg))
153            .collect();
154        let kept: Vec<B::Message> = self.incremental.drain(..).collect();
155        let (final_summary, final_kept) =
156            self.notify_compressed_subscriber(summary, kept).await?;
157        self.compressed = final_summary;
158        self.incremental = final_kept;
159        Ok(())
160    }
161}
162
163// ---------------------------------------------------------------------------
164// 变更消息
165// ---------------------------------------------------------------------------
166
167impl<B: ContextBackend> Message<RequestAppend<B::Message>> for AgentContext<B> {
168    type Reply = ();
169
170    async fn handle(
171        &mut self,
172        msg: RequestAppend<B::Message>,
173        _ctx: &mut Context<Self, Self::Reply>,
174    ) -> Self::Reply {
175        self.incremental.push(msg.message);
176        if let Some(last) = self.incremental.last().cloned() {
177            self.notify_change(NotifyChange::Appended(last)).await;
178        }
179    }
180}
181
182impl<B: ContextBackend> Message<RequestExtend<B::Message>> for AgentContext<B> {
183    type Reply = ();
184
185    async fn handle(
186        &mut self,
187        msg: RequestExtend<B::Message>,
188        _ctx: &mut Context<Self, Self::Reply>,
189    ) -> Self::Reply {
190        for m in msg.messages {
191            self.incremental.push(m);
192            if let Some(last) = self.incremental.last().cloned() {
193                self.notify_change(NotifyChange::Appended(last)).await;
194            }
195        }
196    }
197}
198
199impl<B: ContextBackend> Message<RequestUpdate<B::Message>> for AgentContext<B> {
200    type Reply = Result<(), AgentError>;
201
202    async fn handle(
203        &mut self,
204        msg: RequestUpdate<B::Message>,
205        _ctx: &mut Context<Self, Self::Reply>,
206    ) -> Self::Reply {
207        if msg.index >= self.incremental.len() {
208            return Err(AgentError::Context("索引越界".into()));
209        }
210        let old = std::mem::replace(&mut self.incremental[msg.index], msg.message);
211        self.notify_change(NotifyChange::Updated {
212            index: msg.index,
213            old,
214            new: self.incremental[msg.index].clone(),
215        })
216        .await;
217        Ok(())
218    }
219}
220
221impl<B: ContextBackend> Message<RequestInsert<B::Message>> for AgentContext<B> {
222    type Reply = Result<(), AgentError>;
223
224    async fn handle(
225        &mut self,
226        msg: RequestInsert<B::Message>,
227        _ctx: &mut Context<Self, Self::Reply>,
228    ) -> Self::Reply {
229        if msg.index > self.incremental.len() {
230            return Err(AgentError::Context("索引越界".into()));
231        }
232        self.incremental.insert(msg.index, msg.message);
233        self.notify_change(NotifyChange::Inserted {
234            index: msg.index,
235            message: self.incremental[msg.index].clone(),
236        })
237        .await;
238        Ok(())
239    }
240}
241
242impl<B: ContextBackend> Message<RequestRemove> for AgentContext<B> {
243    type Reply = Result<(), AgentError>;
244
245    async fn handle(
246        &mut self,
247        msg: RequestRemove,
248        _ctx: &mut Context<Self, Self::Reply>,
249    ) -> Self::Reply {
250        if msg.index >= self.incremental.len() {
251            return Err(AgentError::Context("索引越界".into()));
252        }
253        let removed = self.incremental.remove(msg.index);
254        self.notify_change(NotifyChange::Removed {
255            index: msg.index,
256            message: removed,
257        })
258        .await;
259        Ok(())
260    }
261}
262
263impl<B: ContextBackend> Message<RequestPop> for AgentContext<B> {
264    type Reply = Option<B::Message>;
265
266    async fn handle(
267        &mut self,
268        _msg: RequestPop,
269        _ctx: &mut Context<Self, Self::Reply>,
270    ) -> Self::Reply {
271        let popped = self.incremental.pop();
272        if let Some(ref msg) = popped {
273            self.notify_change(NotifyChange::Popped(msg.clone())).await;
274        }
275        popped
276    }
277}
278
279impl<B: ContextBackend> Message<RequestRetain> for AgentContext<B> {
280    type Reply = ();
281
282    async fn handle(
283        &mut self,
284        msg: RequestRetain,
285        _ctx: &mut Context<Self, Self::Reply>,
286    ) -> Self::Reply {
287        let mut removed = Vec::new();
288        let role = msg.role;
289        self.incremental.retain(|m| {
290            if m.role() == role {
291                true
292            } else {
293                removed.push(m.clone());
294                false
295            }
296        });
297        self.notify_change(NotifyChange::Retained { role, removed })
298            .await;
299    }
300}
301
302impl<B: ContextBackend> Message<RequestClear> for AgentContext<B> {
303    type Reply = ();
304
305    async fn handle(
306        &mut self,
307        _msg: RequestClear,
308        _ctx: &mut Context<Self, Self::Reply>,
309    ) -> Self::Reply {
310        if !self.incremental.is_empty() {
311            let removed = std::mem::take(&mut self.incremental);
312            self.notify_change(NotifyChange::Cleared { removed }).await;
313        }
314    }
315}
316
317// ---------------------------------------------------------------------------
318// 查询消息
319// ---------------------------------------------------------------------------
320
321impl<B: ContextBackend> Message<RequestLen> for AgentContext<B> {
322    type Reply = usize;
323
324    async fn handle(
325        &mut self,
326        _msg: RequestLen,
327        _ctx: &mut Context<Self, Self::Reply>,
328    ) -> Self::Reply {
329        self.immutable.len() + self.compressed.len() + self.incremental.len()
330    }
331}
332
333impl<B: ContextBackend> Message<RequestIsEmpty> for AgentContext<B> {
334    type Reply = bool;
335
336    async fn handle(
337        &mut self,
338        _msg: RequestIsEmpty,
339        _ctx: &mut Context<Self, Self::Reply>,
340    ) -> Self::Reply {
341        self.immutable.is_empty() && self.compressed.is_empty() && self.incremental.is_empty()
342    }
343}
344
345impl<B: ContextBackend> Message<RequestGet> for AgentContext<B> {
346    type Reply = Option<B::Message>;
347
348    async fn handle(
349        &mut self,
350        msg: RequestGet,
351        _ctx: &mut Context<Self, Self::Reply>,
352    ) -> Self::Reply {
353        let idx = msg.0;
354        let imm_len = self.immutable.len();
355        let comp_len = self.compressed.len();
356        if idx < imm_len {
357            Some(self.immutable[idx].clone())
358        } else if idx < imm_len + comp_len {
359            Some(self.compressed[idx - imm_len].clone())
360        } else if idx < imm_len + comp_len + self.incremental.len() {
361            Some(self.incremental[idx - imm_len - comp_len].clone())
362        } else {
363            None
364        }
365    }
366}
367
368impl<B: ContextBackend> Message<RequestMessages> for AgentContext<B> {
369    type Reply = Vec<B::Message>;
370
371    async fn handle(
372        &mut self,
373        _msg: RequestMessages,
374        _ctx: &mut Context<Self, Self::Reply>,
375    ) -> Self::Reply {
376        self.immutable
377            .iter()
378            .chain(self.compressed.iter())
379            .chain(self.incremental.iter())
380            .cloned()
381            .collect()
382    }
383}
384
385impl<B: ContextBackend> Message<RequestImmutable> for AgentContext<B> {
386    type Reply = Vec<B::Message>;
387
388    async fn handle(
389        &mut self,
390        _msg: RequestImmutable,
391        _ctx: &mut Context<Self, Self::Reply>,
392    ) -> Self::Reply {
393        self.immutable.to_vec()
394    }
395}
396
397impl<B: ContextBackend> Message<RequestCompressed> for AgentContext<B> {
398    type Reply = Vec<B::Message>;
399
400    async fn handle(
401        &mut self,
402        _msg: RequestCompressed,
403        _ctx: &mut Context<Self, Self::Reply>,
404    ) -> Self::Reply {
405        self.compressed.clone()
406    }
407}
408
409impl<B: ContextBackend> Message<RequestIncremental> for AgentContext<B> {
410    type Reply = Vec<B::Message>;
411
412    async fn handle(
413        &mut self,
414        _msg: RequestIncremental,
415        _ctx: &mut Context<Self, Self::Reply>,
416    ) -> Self::Reply {
417        self.incremental.clone()
418    }
419}
420
421impl<B: ContextBackend> Message<RequestFindByRole> for AgentContext<B> {
422    type Reply = Vec<B::Message>;
423
424    async fn handle(
425        &mut self,
426        msg: RequestFindByRole,
427        _ctx: &mut Context<Self, Self::Reply>,
428    ) -> Self::Reply {
429        self.immutable
430            .iter()
431            .chain(self.compressed.iter())
432            .chain(self.incremental.iter())
433            .filter(|m| m.role() == msg.0)
434            .cloned()
435            .collect()
436    }
437}
438
439// ---------------------------------------------------------------------------
440// 对话消息
441// ---------------------------------------------------------------------------
442
443impl<B: ContextBackend> Message<RequestSend<B::Opts>> for AgentContext<B> {
444    type Reply = Result<B::Response, AgentError>;
445
446    async fn handle(
447        &mut self,
448        msg: RequestSend<B::Opts>,
449        _ctx: &mut Context<Self, Self::Reply>,
450    ) -> Self::Reply {
451        self.compress_if_full(&msg.opts).await?;
452        let scratch = msg.opts.as_ref().scratch.clone();
453        let mut all_messages: Vec<B::Message> = self
454            .immutable
455            .iter()
456            .chain(self.compressed.iter())
457            .chain(self.incremental.iter())
458            .cloned()
459            .collect();
460        if let Some(content) = scratch {
461            all_messages.push(self.backend.system_message(content));
462        }
463        let response = self.backend.send(&all_messages, &msg.opts).await?;
464        let raw_msgs = self
465            .backend
466            .extract_messages(std::slice::from_ref(&response))?;
467        let request_msgs = self.backend.to_request_messages(raw_msgs)?;
468        for msg in &request_msgs {
469            self.incremental.push(msg.clone());
470            self.notify_change(NotifyChange::Appended(msg.clone()))
471                .await;
472        }
473        Ok(response)
474    }
475}
476
477impl<B: ContextBackend + Clone> Message<RequestSendStream<B::Opts>> for AgentContext<B> {
478    type Reply = Result<AgentSendStream<B>, AgentError>;
479
480    async fn handle(
481        &mut self,
482        msg: RequestSendStream<B::Opts>,
483        _ctx: &mut Context<Self, Self::Reply>,
484    ) -> Self::Reply {
485        self.compress_if_full(&msg.opts).await?;
486        let scratch = msg.opts.as_ref().scratch.clone();
487        let mut all_messages: Vec<B::Message> = self
488            .immutable
489            .iter()
490            .chain(self.compressed.iter())
491            .chain(self.incremental.iter())
492            .cloned()
493            .collect();
494        if let Some(content) = scratch {
495            all_messages.push(self.backend.system_message(content));
496        }
497        let stream = self.backend.send_stream(all_messages, msg.opts);
498        Ok(AgentSendStream::new(stream))
499    }
500}
501
502// ---------------------------------------------------------------------------
503// 压缩消息
504// ---------------------------------------------------------------------------
505
506impl<B: ContextBackend> Message<RequestCompress<B::Opts>> for AgentContext<B> {
507    type Reply = Result<(), AgentError>;
508
509    async fn handle(
510        &mut self,
511        msg: RequestCompress<B::Opts>,
512        _ctx: &mut Context<Self, Self::Reply>,
513    ) -> Self::Reply {
514        match msg.strategy {
515            CompressStrategy::Summarize { keep, prompt } => {
516                let total = self.incremental.len();
517                if total > keep {
518                    let split = total - keep;
519                    let to_summarize: Vec<B::Message> = self.incremental.drain(..split).collect();
520                    if !to_summarize.is_empty() {
521                        let summary_prompt = prompt.unwrap_or_else(Self::default_summary_prompt);
522                        let mut summary_messages =
523                            vec![self.backend.system_message(summary_prompt)];
524                        summary_messages.append(&mut self.compressed);
525                        summary_messages.extend(to_summarize);
526                        let response = self.backend.send(&summary_messages, &msg.opts).await?;
527                        let raw_msgs = self
528                            .backend
529                            .extract_messages(std::slice::from_ref(&response))?;
530                        let request_msgs = self.backend.to_request_messages(raw_msgs)?;
531                        let summary: Vec<B::Message> = request_msgs
532                            .into_iter()
533                            .map(|msg| self.backend.to_system_message(msg))
534                            .collect();
535                        let kept: Vec<B::Message> = self.incremental.drain(..).collect();
536                        let (final_summary, final_kept) =
537                            self.notify_compressed_subscriber(summary, kept).await?;
538                        self.compressed = final_summary;
539                        self.incremental = final_kept;
540                    }
541                }
542                Ok(())
543            }
544        }
545    }
546}
547
548// ---------------------------------------------------------------------------
549// 订阅管理消息
550// ---------------------------------------------------------------------------
551
552impl<B: ContextBackend> Message<RequestSubscribeChange<B::Message>> for AgentContext<B> {
553    type Reply = ();
554
555    async fn handle(
556        &mut self,
557        msg: RequestSubscribeChange<B::Message>,
558        _ctx: &mut Context<Self, Self::Reply>,
559    ) -> Self::Reply {
560        self.subscribers.insert(msg.recipient);
561    }
562}
563
564impl<B: ContextBackend> Message<RequestUnsubscribeChange<B::Message>> for AgentContext<B> {
565    type Reply = bool;
566
567    async fn handle(
568        &mut self,
569        msg: RequestUnsubscribeChange<B::Message>,
570        _ctx: &mut Context<Self, Self::Reply>,
571    ) -> Self::Reply {
572        self.subscribers.remove(&msg.recipient)
573    }
574}
575
576impl<B: ContextBackend> Message<RequestSubscribeCompressed<B::Message>> for AgentContext<B> {
577    type Reply = ();
578
579    async fn handle(
580        &mut self,
581        msg: RequestSubscribeCompressed<B::Message>,
582        _ctx: &mut Context<Self, Self::Reply>,
583    ) -> Self::Reply {
584        self.on_compressed = Some(msg.recipient);
585    }
586}
587
588impl<B: ContextBackend> Message<RequestUnsubscribeCompressed> for AgentContext<B> {
589    type Reply = ();
590
591    async fn handle(
592        &mut self,
593        _msg: RequestUnsubscribeCompressed,
594        _ctx: &mut Context<Self, Self::Reply>,
595    ) -> Self::Reply {
596        self.on_compressed = None;
597    }
598}
599
600// ---------------------------------------------------------------------------
601// 工具消息
602// ---------------------------------------------------------------------------
603
604impl<B: ContextBackend> Message<RequestEstimateTokens> for AgentContext<B> {
605    type Reply = usize;
606
607    async fn handle(
608        &mut self,
609        _msg: RequestEstimateTokens,
610        _ctx: &mut Context<Self, Self::Reply>,
611    ) -> Self::Reply {
612        let all: Vec<B::Message> = self
613            .immutable
614            .iter()
615            .chain(self.compressed.iter())
616            .chain(self.incremental.iter())
617            .cloned()
618            .collect();
619        self.backend.estimate_tokens(&all).await.unwrap_or(0)
620    }
621}
622
623impl<B: ContextBackend> Message<RequestExportAll> for AgentContext<B> {
624    type Reply = Result<String, AgentError>;
625
626    async fn handle(
627        &mut self,
628        _msg: RequestExportAll,
629        _ctx: &mut Context<Self, Self::Reply>,
630    ) -> Self::Reply {
631        let lines: Vec<String> = self
632            .immutable
633            .iter()
634            .chain(self.compressed.iter())
635            .chain(self.incremental.iter())
636            .map(|m| self.backend.message_to_json(m))
637            .collect::<Result<_, _>>()?;
638        Ok(lines.join("\n"))
639    }
640}
641
642impl<B: ContextBackend> Message<RequestExportIncremental> for AgentContext<B> {
643    type Reply = Result<String, AgentError>;
644
645    async fn handle(
646        &mut self,
647        _msg: RequestExportIncremental,
648        _ctx: &mut Context<Self, Self::Reply>,
649    ) -> Self::Reply {
650        let lines: Vec<String> = self
651            .incremental
652            .iter()
653            .map(|m| self.backend.message_to_json(m))
654            .collect::<Result<_, _>>()?;
655        Ok(lines.join("\n"))
656    }
657}
658
659impl<B: ContextBackend> Message<RequestImportIncremental> for AgentContext<B> {
660    type Reply = Result<(), AgentError>;
661
662    async fn handle(
663        &mut self,
664        msg: RequestImportIncremental,
665        _ctx: &mut Context<Self, Self::Reply>,
666    ) -> Self::Reply {
667        let mut messages = Vec::new();
668        for line in msg.json.lines() {
669            let line = line.trim();
670            if line.is_empty() {
671                continue;
672            }
673            messages.push(self.backend.message_from_json(line)?);
674        }
675        self.incremental.clear();
676        self.incremental = messages.clone();
677        self.notify_change(NotifyChange::Loaded { messages }).await;
678        Ok(())
679    }
680}