Skip to main content

agent_context/context/
actor.rs

1//! [`AgentContext`] Actor 及所有 [`Message`] 实现。
2//!
3//! 管理三区消息模型的状态,提供消息增删改查、模型对话、上下文压缩等操作。
4
5use kameo::prelude::*;
6
7use super::event::{
8    CompressStrategy, NotifyChange, NotifyCompressedForReply, RequestAppend, RequestClear,
9    RequestCompress, RequestCompressed, RequestEstimateTokens, RequestExtend, RequestFindByRole,
10    RequestFromJsonl, RequestGet, RequestImmutable, RequestIncremental, RequestInsert,
11    RequestIsEmpty, RequestLen, RequestMessages, RequestPop, RequestRemove, RequestRetain,
12    RequestSend, RequestSendStream, RequestToJsonl, RequestUpdate,
13};
14use super::stream::AgentSendStream;
15use super::types::ContextBackend;
16use crate::error::AgentError;
17
/// Reply returned by a compression editor: a pair of message lists that the
/// context destructures as `(summary, kept)`.
type CompressEditorReply<M> = (Vec<M>, Vec<M>);
/// Recipient handle for an actor that receives [`NotifyCompressedForReply`] and
/// answers with a [`CompressEditorReply`].
type CompressEditorRecipient<M> = ReplyRecipient<NotifyCompressedForReply<M>, CompressEditorReply<M>>;
20use crate::message::ContextMessage;
21use crate::readonly::ReadOnly;
22
23// ---------------------------------------------------------------------------
24// AgentContext Actor
25// ---------------------------------------------------------------------------
26
/// LLM conversation context manager, implemented as a kameo Actor.
///
/// Manages the three-zone + scratch message model
/// (immutable → compressed → incremental → scratch) and provides:
/// - message create/read/update/delete ([`RequestAppend`], [`RequestUpdate`], [`RequestRemove`], etc.)
/// - conversation sending ([`RequestSend`], [`RequestSendStream`]), with optional temporary
///   metadata appended through [`super::CommonOpts::scratch`]
/// - context compression ([`RequestCompress`])
/// - token estimation ([`RequestEstimateTokens`])
/// - change notification ([`NotifyChange`]); subscribers are registered with
///   [`subscribe_change`](AgentContext::subscribe_change)
///
/// ## Construction
///
/// ```ignore
/// let ctx = AgentContext::new(backend, vec![])
///     .subscribe_change(app_ref.recipient());
/// let actor = AgentContext::spawn(ctx);
/// ```
#[derive(Actor)]
pub struct AgentContext<B: ContextBackend> {
    /// LLM backend used for sending, token estimation and (de)serialization.
    backend: B,
    /// Immutable zone: messages supplied at construction (system prompt, etc.).
    immutable: ReadOnly<B::Message>,
    /// Compressed zone: summary messages produced by compression.
    compressed: Vec<B::Message>,
    /// Incremental zone: the messages that the mutation requests operate on.
    incremental: Vec<B::Message>,
    /// Recipients notified of every change to the incremental zone.
    subscribers: Vec<Recipient<NotifyChange<B::Message>>>,
    /// Optional editor consulted after a summary is generated and before it is stored.
    on_compressed: Option<CompressEditorRecipient<B::Message>>,
}
52
53impl<B: ContextBackend> AgentContext<B> {
54    /// 创建新的上下文管理器。
55    ///
56    /// - `backend`: 实现了 [`ContextBackend`] 的 LLM 后端实例
57    /// - `immutable`: 初始不可变消息(系统提示词等),放入 immutable 区
58    pub fn new(backend: B, immutable: Vec<B::Message>) -> Self {
59        Self {
60            backend,
61            immutable: ReadOnly::from(immutable),
62            compressed: Vec::new(),
63            incremental: Vec::new(),
64            subscribers: Vec::new(),
65            on_compressed: None,
66        }
67    }
68
69    /// 订阅增量区变更通知。
70    ///
71    /// 每次对 incremental 区的写操作(追加/更新/插入/移除/清空等)都会通知所有订阅者。
72    /// 用于 CLI 实时展示、日志记录等场景。可多次调用注册多个订阅者。
73    pub fn subscribe_change(mut self, recipient: Recipient<NotifyChange<B::Message>>) -> Self {
74        self.subscribers.push(recipient);
75        self
76    }
77
78    /// 注册压缩消息或摘要后编辑者 Actor。
79    ///
80    /// 在 [`RequestCompress`] 生成摘要后、写入 compressed 区之前,
81    /// 将 [`NotifyCompressedForReply`] 发送给压缩消息或摘要后编辑者 Actor,压缩消息或摘要后编辑者返回修改后的 `(摘要, 保留)` 对。
82    /// 用于自定义后处理(如过滤、重新排序摘要内容)。
83    pub fn subscribe_compressed(mut self, recipient: CompressEditorRecipient<B::Message>) -> Self {
84        self.on_compressed = Some(recipient);
85        self
86    }
87
88    fn default_summary_prompt() -> String {
89        "请将以下对话历史压缩为简洁摘要,保留关键信息、决策和上下文。输出一条 system 消息。"
90            .to_string()
91    }
92
93    async fn notify_change(&self, event: NotifyChange<B::Message>) {
94        for subscriber in &self.subscribers {
95            if let Err(e) = subscriber.tell(event.clone()).send().await {
96                unreachable!("通知订阅者失败: {e:?}");
97            }
98        }
99    }
100
101    /// 检查上下文是否已满,如果已满且 `auto_compress` 则自动压缩。
102    async fn compress_if_full(&mut self, opts: &B::Opts) -> Result<(), AgentError> {
103        let common = opts.as_ref();
104        let all: Vec<B::Message> = self
105            .immutable
106            .iter()
107            .chain(self.compressed.iter())
108            .chain(self.incremental.iter())
109            .cloned()
110            .collect();
111        let tokens = self
112            .backend
113            .estimate_tokens(&all)
114            .await
115            .unwrap_or(usize::MAX);
116        if tokens < common.context_window {
117            return Ok(());
118        }
119        if !common.auto_compress {
120            return Err(AgentError::Context("上下文已满且未启用自动压缩".into()));
121        }
122        let total = self.incremental.len();
123        let keep = total / 2;
124        if total <= keep {
125            return Ok(());
126        }
127        let split = total - keep;
128        let to_summarize: Vec<B::Message> = self.incremental.drain(..split).collect();
129        if to_summarize.is_empty() {
130            return Ok(());
131        }
132        let mut summary_messages =
133            vec![self.backend.system_message(Self::default_summary_prompt())];
134        summary_messages.append(&mut self.compressed);
135        summary_messages.extend(to_summarize);
136        let response = self.backend.send(&summary_messages, opts).await?;
137        let raw_msgs = self
138            .backend
139            .extract_messages(std::slice::from_ref(&response))?;
140        let request_msgs = self.backend.to_request_messages(raw_msgs)?;
141        let summary: Vec<B::Message> = request_msgs
142            .into_iter()
143            .map(|msg| self.backend.to_system_message(msg))
144            .collect();
145        let kept: Vec<B::Message> = self.incremental.drain(..).collect();
146        let (final_summary, final_kept) = if let Some(editor) = &self.on_compressed {
147            editor
148                .ask(NotifyCompressedForReply { summary, kept })
149                .send()
150                .await
151                .map_err(|e| AgentError::Context(e.to_string()))?
152        } else {
153            (summary, kept)
154        };
155        self.compressed = final_summary;
156        self.incremental = final_kept;
157        Ok(())
158    }
159}
160
161// ---------------------------------------------------------------------------
162// 变更消息
163// ---------------------------------------------------------------------------
164
165impl<B: ContextBackend> Message<RequestAppend<B::Message>> for AgentContext<B> {
166    type Reply = ();
167
168    async fn handle(
169        &mut self,
170        msg: RequestAppend<B::Message>,
171        _ctx: &mut Context<Self, Self::Reply>,
172    ) -> Self::Reply {
173        self.incremental.push(msg.message);
174        if let Some(last) = self.incremental.last().cloned() {
175            self.notify_change(NotifyChange::Appended(last)).await;
176        }
177    }
178}
179
180impl<B: ContextBackend> Message<RequestExtend<B::Message>> for AgentContext<B> {
181    type Reply = ();
182
183    async fn handle(
184        &mut self,
185        msg: RequestExtend<B::Message>,
186        _ctx: &mut Context<Self, Self::Reply>,
187    ) -> Self::Reply {
188        for m in msg.messages {
189            self.incremental.push(m);
190            if let Some(last) = self.incremental.last().cloned() {
191                self.notify_change(NotifyChange::Appended(last)).await;
192            }
193        }
194    }
195}
196
197impl<B: ContextBackend> Message<RequestUpdate<B::Message>> for AgentContext<B> {
198    type Reply = Result<(), AgentError>;
199
200    async fn handle(
201        &mut self,
202        msg: RequestUpdate<B::Message>,
203        _ctx: &mut Context<Self, Self::Reply>,
204    ) -> Self::Reply {
205        if msg.index >= self.incremental.len() {
206            return Err(AgentError::Context("索引越界".into()));
207        }
208        let old = std::mem::replace(&mut self.incremental[msg.index], msg.message);
209        self.notify_change(NotifyChange::Updated {
210            index: msg.index,
211            old,
212            new: self.incremental[msg.index].clone(),
213        })
214        .await;
215        Ok(())
216    }
217}
218
219impl<B: ContextBackend> Message<RequestInsert<B::Message>> for AgentContext<B> {
220    type Reply = Result<(), AgentError>;
221
222    async fn handle(
223        &mut self,
224        msg: RequestInsert<B::Message>,
225        _ctx: &mut Context<Self, Self::Reply>,
226    ) -> Self::Reply {
227        if msg.index > self.incremental.len() {
228            return Err(AgentError::Context("索引越界".into()));
229        }
230        self.incremental.insert(msg.index, msg.message);
231        self.notify_change(NotifyChange::Inserted {
232            index: msg.index,
233            message: self.incremental[msg.index].clone(),
234        })
235        .await;
236        Ok(())
237    }
238}
239
240impl<B: ContextBackend> Message<RequestRemove> for AgentContext<B> {
241    type Reply = Result<(), AgentError>;
242
243    async fn handle(
244        &mut self,
245        msg: RequestRemove,
246        _ctx: &mut Context<Self, Self::Reply>,
247    ) -> Self::Reply {
248        if msg.index >= self.incremental.len() {
249            return Err(AgentError::Context("索引越界".into()));
250        }
251        let removed = self.incremental.remove(msg.index);
252        self.notify_change(NotifyChange::Removed {
253            index: msg.index,
254            message: removed,
255        })
256        .await;
257        Ok(())
258    }
259}
260
261impl<B: ContextBackend> Message<RequestPop> for AgentContext<B> {
262    type Reply = Option<B::Message>;
263
264    async fn handle(
265        &mut self,
266        _msg: RequestPop,
267        _ctx: &mut Context<Self, Self::Reply>,
268    ) -> Self::Reply {
269        let popped = self.incremental.pop();
270        if let Some(ref msg) = popped {
271            self.notify_change(NotifyChange::Popped(msg.clone())).await;
272        }
273        popped
274    }
275}
276
277impl<B: ContextBackend> Message<RequestRetain> for AgentContext<B> {
278    type Reply = ();
279
280    async fn handle(
281        &mut self,
282        msg: RequestRetain,
283        _ctx: &mut Context<Self, Self::Reply>,
284    ) -> Self::Reply {
285        let mut removed = Vec::new();
286        let role = msg.role;
287        self.incremental.retain(|m| {
288            if m.role() == role {
289                true
290            } else {
291                removed.push(m.clone());
292                false
293            }
294        });
295        self.notify_change(NotifyChange::Retained { role, removed })
296            .await;
297    }
298}
299
300impl<B: ContextBackend> Message<RequestClear> for AgentContext<B> {
301    type Reply = ();
302
303    async fn handle(
304        &mut self,
305        _msg: RequestClear,
306        _ctx: &mut Context<Self, Self::Reply>,
307    ) -> Self::Reply {
308        if !self.incremental.is_empty() {
309            let removed = std::mem::take(&mut self.incremental);
310            self.notify_change(NotifyChange::Cleared { removed }).await;
311        }
312    }
313}
314
315// ---------------------------------------------------------------------------
316// 查询消息
317// ---------------------------------------------------------------------------
318
319impl<B: ContextBackend> Message<RequestLen> for AgentContext<B> {
320    type Reply = usize;
321
322    async fn handle(
323        &mut self,
324        _msg: RequestLen,
325        _ctx: &mut Context<Self, Self::Reply>,
326    ) -> Self::Reply {
327        self.immutable.len() + self.compressed.len() + self.incremental.len()
328    }
329}
330
331impl<B: ContextBackend> Message<RequestIsEmpty> for AgentContext<B> {
332    type Reply = bool;
333
334    async fn handle(
335        &mut self,
336        _msg: RequestIsEmpty,
337        _ctx: &mut Context<Self, Self::Reply>,
338    ) -> Self::Reply {
339        self.immutable.is_empty() && self.compressed.is_empty() && self.incremental.is_empty()
340    }
341}
342
343impl<B: ContextBackend> Message<RequestGet> for AgentContext<B> {
344    type Reply = Option<B::Message>;
345
346    async fn handle(
347        &mut self,
348        msg: RequestGet,
349        _ctx: &mut Context<Self, Self::Reply>,
350    ) -> Self::Reply {
351        let idx = msg.0;
352        let imm_len = self.immutable.len();
353        let comp_len = self.compressed.len();
354        if idx < imm_len {
355            Some(self.immutable[idx].clone())
356        } else if idx < imm_len + comp_len {
357            Some(self.compressed[idx - imm_len].clone())
358        } else if idx < imm_len + comp_len + self.incremental.len() {
359            Some(self.incremental[idx - imm_len - comp_len].clone())
360        } else {
361            None
362        }
363    }
364}
365
366impl<B: ContextBackend> Message<RequestMessages> for AgentContext<B> {
367    type Reply = Vec<B::Message>;
368
369    async fn handle(
370        &mut self,
371        _msg: RequestMessages,
372        _ctx: &mut Context<Self, Self::Reply>,
373    ) -> Self::Reply {
374        self.immutable
375            .iter()
376            .chain(self.compressed.iter())
377            .chain(self.incremental.iter())
378            .cloned()
379            .collect()
380    }
381}
382
383impl<B: ContextBackend> Message<RequestImmutable> for AgentContext<B> {
384    type Reply = Vec<B::Message>;
385
386    async fn handle(
387        &mut self,
388        _msg: RequestImmutable,
389        _ctx: &mut Context<Self, Self::Reply>,
390    ) -> Self::Reply {
391        self.immutable.to_vec()
392    }
393}
394
395impl<B: ContextBackend> Message<RequestCompressed> for AgentContext<B> {
396    type Reply = Vec<B::Message>;
397
398    async fn handle(
399        &mut self,
400        _msg: RequestCompressed,
401        _ctx: &mut Context<Self, Self::Reply>,
402    ) -> Self::Reply {
403        self.compressed.clone()
404    }
405}
406
407impl<B: ContextBackend> Message<RequestIncremental> for AgentContext<B> {
408    type Reply = Vec<B::Message>;
409
410    async fn handle(
411        &mut self,
412        _msg: RequestIncremental,
413        _ctx: &mut Context<Self, Self::Reply>,
414    ) -> Self::Reply {
415        self.incremental.clone()
416    }
417}
418
419impl<B: ContextBackend> Message<RequestFindByRole> for AgentContext<B> {
420    type Reply = Vec<B::Message>;
421
422    async fn handle(
423        &mut self,
424        msg: RequestFindByRole,
425        _ctx: &mut Context<Self, Self::Reply>,
426    ) -> Self::Reply {
427        self.immutable
428            .iter()
429            .chain(self.compressed.iter())
430            .chain(self.incremental.iter())
431            .filter(|m| m.role() == msg.0)
432            .cloned()
433            .collect()
434    }
435}
436
437// ---------------------------------------------------------------------------
438// 对话消息
439// ---------------------------------------------------------------------------
440
441impl<B: ContextBackend> Message<RequestSend<B::Opts>> for AgentContext<B> {
442    type Reply = Result<B::Response, AgentError>;
443
444    async fn handle(
445        &mut self,
446        msg: RequestSend<B::Opts>,
447        _ctx: &mut Context<Self, Self::Reply>,
448    ) -> Self::Reply {
449        self.compress_if_full(&msg.opts).await?;
450        let scratch = msg.opts.as_ref().scratch.clone();
451        let mut all_messages: Vec<B::Message> = self
452            .immutable
453            .iter()
454            .chain(self.compressed.iter())
455            .chain(self.incremental.iter())
456            .cloned()
457            .collect();
458        if let Some(content) = scratch {
459            all_messages.push(self.backend.system_message(content));
460        }
461        let response = self.backend.send(&all_messages, &msg.opts).await?;
462        let raw_msgs = self
463            .backend
464            .extract_messages(std::slice::from_ref(&response))?;
465        let request_msgs = self.backend.to_request_messages(raw_msgs)?;
466        for msg in &request_msgs {
467            self.incremental.push(msg.clone());
468            self.notify_change(NotifyChange::Appended(msg.clone()))
469                .await;
470        }
471        Ok(response)
472    }
473}
474
475impl<B: ContextBackend + Clone> Message<RequestSendStream<B::Opts>> for AgentContext<B> {
476    type Reply = Result<AgentSendStream<B>, AgentError>;
477
478    async fn handle(
479        &mut self,
480        msg: RequestSendStream<B::Opts>,
481        _ctx: &mut Context<Self, Self::Reply>,
482    ) -> Self::Reply {
483        self.compress_if_full(&msg.opts).await?;
484        let scratch = msg.opts.as_ref().scratch.clone();
485        let mut all_messages: Vec<B::Message> = self
486            .immutable
487            .iter()
488            .chain(self.compressed.iter())
489            .chain(self.incremental.iter())
490            .cloned()
491            .collect();
492        if let Some(content) = scratch {
493            all_messages.push(self.backend.system_message(content));
494        }
495        let stream = self.backend.send_stream(all_messages, msg.opts);
496        Ok(AgentSendStream::new(stream))
497    }
498}
499
500// ---------------------------------------------------------------------------
501// 压缩消息
502// ---------------------------------------------------------------------------
503
504impl<B: ContextBackend> Message<RequestCompress<B::Opts>> for AgentContext<B> {
505    type Reply = Result<(), AgentError>;
506
507    async fn handle(
508        &mut self,
509        msg: RequestCompress<B::Opts>,
510        _ctx: &mut Context<Self, Self::Reply>,
511    ) -> Self::Reply {
512        match msg.strategy {
513            CompressStrategy::Summarize { keep, prompt } => {
514                let total = self.incremental.len();
515                if total > keep {
516                    let split = total - keep;
517                    let to_summarize: Vec<B::Message> = self.incremental.drain(..split).collect();
518                    if !to_summarize.is_empty() {
519                        let summary_prompt = prompt.unwrap_or_else(Self::default_summary_prompt);
520                        let mut summary_messages =
521                            vec![self.backend.system_message(summary_prompt)];
522                        summary_messages.append(&mut self.compressed);
523                        summary_messages.extend(to_summarize);
524                        let response = self.backend.send(&summary_messages, &msg.opts).await?;
525                        let raw_msgs = self
526                            .backend
527                            .extract_messages(std::slice::from_ref(&response))?;
528                        let request_msgs = self.backend.to_request_messages(raw_msgs)?;
529                        let summary: Vec<B::Message> = request_msgs
530                            .into_iter()
531                            .map(|msg| self.backend.to_system_message(msg))
532                            .collect();
533                        let kept: Vec<B::Message> = self.incremental.drain(..).collect();
534                        let (final_summary, final_kept) = if let Some(editor) =
535                            &self.on_compressed
536                        {
537                            editor
538                                .ask(NotifyCompressedForReply { summary, kept })
539                                .send()
540                                .await
541                                .map_err(|e| AgentError::Context(e.to_string()))?
542                        } else {
543                            (summary, kept)
544                        };
545                        self.compressed = final_summary;
546                        self.incremental = final_kept;
547                    }
548                }
549                Ok(())
550            }
551        }
552    }
553}
554
555// ---------------------------------------------------------------------------
556// 工具消息
557// ---------------------------------------------------------------------------
558
559impl<B: ContextBackend> Message<RequestEstimateTokens> for AgentContext<B> {
560    type Reply = usize;
561
562    async fn handle(
563        &mut self,
564        _msg: RequestEstimateTokens,
565        _ctx: &mut Context<Self, Self::Reply>,
566    ) -> Self::Reply {
567        let all: Vec<B::Message> = self
568            .immutable
569            .iter()
570            .chain(self.compressed.iter())
571            .chain(self.incremental.iter())
572            .cloned()
573            .collect();
574        self.backend.estimate_tokens(&all).await.unwrap_or(0)
575    }
576}
577
578impl<B: ContextBackend> Message<RequestToJsonl> for AgentContext<B> {
579    type Reply = Result<String, AgentError>;
580
581    async fn handle(
582        &mut self,
583        _msg: RequestToJsonl,
584        _ctx: &mut Context<Self, Self::Reply>,
585    ) -> Self::Reply {
586        let lines: Vec<String> = self
587            .immutable
588            .iter()
589            .chain(self.compressed.iter())
590            .chain(self.incremental.iter())
591            .map(|m| self.backend.message_to_jsonl(m))
592            .collect::<Result<_, _>>()?;
593        Ok(lines.join("\n"))
594    }
595}
596
597impl<B: ContextBackend> Message<RequestFromJsonl> for AgentContext<B> {
598    type Reply = Result<(), AgentError>;
599
600    async fn handle(
601        &mut self,
602        msg: RequestFromJsonl,
603        _ctx: &mut Context<Self, Self::Reply>,
604    ) -> Self::Reply {
605        for line in msg.jsonl.lines() {
606            let line = line.trim();
607            if line.is_empty() {
608                continue;
609            }
610            let message: B::Message = self.backend.message_from_jsonl(line)?;
611            self.incremental.push(message.clone());
612            self.notify_change(NotifyChange::Appended(message)).await;
613        }
614        Ok(())
615    }
616}