Skip to main content

agent_context/context/
actor.rs

1//! [`AgentContext`] Actor 及所有 [`Message`] 实现。
2//!
3//! 管理三区消息模型的状态,提供消息增删改查、模型对话、上下文压缩等操作。
4
5use kameo::prelude::*;
6
7use super::events::{
8    CompressStrategy, NotifyCompressedForReply, RequestAppend, RequestClear,
9    RequestCompress, RequestCompressed, RequestEstimateTokens, RequestExtend, RequestFindByRole,
10    RequestImportIncremental, RequestGet, RequestImmutable, RequestIncremental, RequestInsert,
11    RequestIsEmpty, RequestLen, RequestMessages, RequestPop, RequestRemove, RequestRetain,
12    RequestSend, RequestSendStream, RequestSubscribeCompressed, RequestUnsubscribeCompressed,
13    RequestExportIncremental, RequestExportAll, RequestUpdate,
14};
15use super::stream::AgentSendStream;
16use super::types::ContextBackend;
17use crate::error::AgentError;
18
19type CompressSubscriberReply<M> = (Vec<M>, Vec<M>);
20type CompressSubscriberRecipient<M> = ReplyRecipient<NotifyCompressedForReply<M>, CompressSubscriberReply<M>>;
21
22use crate::message::ContextMessage;
23use crate::readonly::ReadOnly;
24
25// ---------------------------------------------------------------------------
26// AgentContext Actor
27// ---------------------------------------------------------------------------
28
/// Conversation-context manager for LLM dialogue, implemented as a kameo actor.
///
/// Manages a three-zone message model plus an optional scratch message
/// (immutable → compressed → incremental → scratch) and provides:
/// - message create/read/update/delete ([`RequestAppend`], [`RequestUpdate`], [`RequestRemove`], etc.)
/// - sending a conversation ([`RequestSend`], [`RequestSendStream`]), with temporary metadata appended through [`super::CommonOpts::scratch`]
/// - context compression ([`RequestCompress`])
/// - token estimation ([`RequestEstimateTokens`])
/// - compression-subscriber management ([`RequestSubscribeCompressed`], [`RequestUnsubscribeCompressed`])
///
/// ## Construction
///
/// ```ignore
/// let actor_ref = AgentContext::spawn(AgentContext::new(backend, vec![]));
/// ```
#[derive(Actor)]
pub struct AgentContext<B: ContextBackend> {
    /// Backend used to send requests, estimate tokens and convert messages.
    backend: B,
    /// Messages supplied at construction time; wrapped in [`ReadOnly`].
    immutable: ReadOnly<B::Message>,
    /// Summary messages produced by the most recent compression.
    compressed: Vec<B::Message>,
    /// Messages the mutation requests operate on; model replies are appended here.
    incremental: Vec<B::Message>,
    /// Optional subscriber that is asked to post-process `(summary, kept)` after a compression.
    on_compressed: Option<CompressSubscriberRecipient<B::Message>>,
}
51
52impl<B: ContextBackend> AgentContext<B> {
53    /// 创建新的上下文管理器。
54    ///
55    /// - `backend`: 实现了 [`ContextBackend`] 的 LLM 后端实例
56    /// - `immutable`: 初始不可变消息(系统提示词等),放入 immutable 区
57    pub fn new(backend: B, immutable: Vec<B::Message>) -> Self {
58        Self {
59            backend,
60            immutable: ReadOnly::from(immutable),
61            compressed: Vec::new(),
62            incremental: Vec::new(),
63            on_compressed: None,
64        }
65    }
66
67    fn default_summary_prompt() -> String {
68        "请将以下对话历史压缩为简洁摘要,保留关键信息、决策和上下文。输出一条 system 消息。"
69            .to_string()
70    }
71
72
73    /// 将摘要和保留消息交给压缩结果订阅者处理,若无订阅者则原样返回。
74    async fn notify_compressed_subscriber(
75        &self,
76        summary: Vec<B::Message>,
77        kept: Vec<B::Message>,
78    ) -> Result<(Vec<B::Message>, Vec<B::Message>), AgentError> {
79        if let Some(subscriber) = &self.on_compressed {
80            subscriber
81                .ask(NotifyCompressedForReply {
82                    summary,
83                    kept,
84                })
85                .send()
86                .await
87                .map_err(|e| AgentError::Context(e.to_string()))
88        } else {
89            Ok((summary, kept))
90        }
91    }
92
93    /// 检查上下文是否已满,如果已满且 `auto_compress` 则自动压缩。
94    async fn compress_if_full(&mut self, opts: &B::Opts) -> Result<(), AgentError> {
95        let common = opts.as_ref();
96        let mut all: Vec<B::Message> = self
97            .immutable
98            .iter()
99            .chain(self.compressed.iter())
100            .chain(self.incremental.iter())
101            .cloned()
102            .collect();
103        if let Some(ref scratch) = common.scratch {
104            all.push(self.backend.system_message(scratch.clone()));
105        }
106        let tokens = self
107            .backend
108            .estimate_tokens(&all)
109            .await
110            .unwrap_or(usize::MAX);
111        if tokens < common.context_window {
112            return Ok(());
113        }
114        if !common.auto_compress {
115            return Err(AgentError::Context("上下文已满且未启用自动压缩".into()));
116        }
117        let total = self.incremental.len();
118        let keep = total / 2;
119        if total <= keep {
120            return Ok(());
121        }
122        let split = total - keep;
123        let to_summarize: Vec<B::Message> = self.incremental.drain(..split).collect();
124        if to_summarize.is_empty() {
125            return Ok(());
126        }
127        let mut summary_messages =
128            vec![self.backend.system_message(Self::default_summary_prompt())];
129        summary_messages.append(&mut self.compressed);
130        summary_messages.extend(to_summarize);
131        let response = self.backend.send(&summary_messages, opts).await?;
132        let raw_msgs = self
133            .backend
134            .extract_messages(std::slice::from_ref(&response))?;
135        let request_msgs = self.backend.to_request_messages(raw_msgs)?;
136        let summary: Vec<B::Message> = request_msgs
137            .into_iter()
138            .map(|msg| self.backend.to_system_message(msg))
139            .collect();
140        let kept: Vec<B::Message> = self.incremental.drain(..).collect();
141        let (final_summary, final_kept) =
142            self.notify_compressed_subscriber(summary, kept).await?;
143        self.compressed = final_summary;
144        self.incremental = final_kept;
145        Ok(())
146    }
147}
148
149// ---------------------------------------------------------------------------
150// 变更消息
151// ---------------------------------------------------------------------------
152
153impl<B: ContextBackend> Message<RequestAppend<B::Message>> for AgentContext<B> {
154    type Reply = ();
155
156    async fn handle(
157        &mut self,
158        msg: RequestAppend<B::Message>,
159        _ctx: &mut Context<Self, Self::Reply>,
160    ) -> Self::Reply {
161        self.incremental.push(msg.message);
162    }
163}
164
165impl<B: ContextBackend> Message<RequestExtend<B::Message>> for AgentContext<B> {
166    type Reply = ();
167
168    async fn handle(
169        &mut self,
170        msg: RequestExtend<B::Message>,
171        _ctx: &mut Context<Self, Self::Reply>,
172    ) -> Self::Reply {
173        for m in msg.messages {
174            self.incremental.push(m);
175        }
176    }
177}
178
179impl<B: ContextBackend> Message<RequestUpdate<B::Message>> for AgentContext<B> {
180    type Reply = Result<(), AgentError>;
181
182    async fn handle(
183        &mut self,
184        msg: RequestUpdate<B::Message>,
185        _ctx: &mut Context<Self, Self::Reply>,
186    ) -> Self::Reply {
187        if msg.index >= self.incremental.len() {
188            return Err(AgentError::Context("索引越界".into()));
189        }
190        self.incremental[msg.index] = msg.message;
191        Ok(())
192    }
193}
194
195impl<B: ContextBackend> Message<RequestInsert<B::Message>> for AgentContext<B> {
196    type Reply = Result<(), AgentError>;
197
198    async fn handle(
199        &mut self,
200        msg: RequestInsert<B::Message>,
201        _ctx: &mut Context<Self, Self::Reply>,
202    ) -> Self::Reply {
203        if msg.index > self.incremental.len() {
204            return Err(AgentError::Context("索引越界".into()));
205        }
206        self.incremental.insert(msg.index, msg.message);
207        Ok(())
208    }
209}
210
211impl<B: ContextBackend> Message<RequestRemove> for AgentContext<B> {
212    type Reply = Result<(), AgentError>;
213
214    async fn handle(
215        &mut self,
216        msg: RequestRemove,
217        _ctx: &mut Context<Self, Self::Reply>,
218    ) -> Self::Reply {
219        if msg.index >= self.incremental.len() {
220            return Err(AgentError::Context("索引越界".into()));
221        }
222        self.incremental.remove(msg.index);
223        Ok(())
224    }
225}
226
227impl<B: ContextBackend> Message<RequestPop> for AgentContext<B> {
228    type Reply = Option<B::Message>;
229
230    async fn handle(
231        &mut self,
232        _msg: RequestPop,
233        _ctx: &mut Context<Self, Self::Reply>,
234    ) -> Self::Reply {
235        self.incremental.pop()
236    }
237}
238
239impl<B: ContextBackend> Message<RequestRetain> for AgentContext<B> {
240    type Reply = ();
241
242    async fn handle(
243        &mut self,
244        msg: RequestRetain,
245        _ctx: &mut Context<Self, Self::Reply>,
246    ) -> Self::Reply {
247        let mut removed = Vec::new();
248        let role = msg.role;
249        self.incremental.retain(|m| {
250            if m.role() == role {
251                true
252            } else {
253                removed.push(m.clone());
254                false
255            }
256        });
257    }
258}
259
260impl<B: ContextBackend> Message<RequestClear> for AgentContext<B> {
261    type Reply = ();
262
263    async fn handle(
264        &mut self,
265        _msg: RequestClear,
266        _ctx: &mut Context<Self, Self::Reply>,
267    ) -> Self::Reply {
268        self.incremental.clear();
269    }
270}
271
272// ---------------------------------------------------------------------------
273// 查询消息
274// ---------------------------------------------------------------------------
275
276impl<B: ContextBackend> Message<RequestLen> for AgentContext<B> {
277    type Reply = usize;
278
279    async fn handle(
280        &mut self,
281        _msg: RequestLen,
282        _ctx: &mut Context<Self, Self::Reply>,
283    ) -> Self::Reply {
284        self.immutable.len() + self.compressed.len() + self.incremental.len()
285    }
286}
287
288impl<B: ContextBackend> Message<RequestIsEmpty> for AgentContext<B> {
289    type Reply = bool;
290
291    async fn handle(
292        &mut self,
293        _msg: RequestIsEmpty,
294        _ctx: &mut Context<Self, Self::Reply>,
295    ) -> Self::Reply {
296        self.immutable.is_empty() && self.compressed.is_empty() && self.incremental.is_empty()
297    }
298}
299
300impl<B: ContextBackend> Message<RequestGet> for AgentContext<B> {
301    type Reply = Option<B::Message>;
302
303    async fn handle(
304        &mut self,
305        msg: RequestGet,
306        _ctx: &mut Context<Self, Self::Reply>,
307    ) -> Self::Reply {
308        let idx = msg.0;
309        let imm_len = self.immutable.len();
310        let comp_len = self.compressed.len();
311        if idx < imm_len {
312            Some(self.immutable[idx].clone())
313        } else if idx < imm_len + comp_len {
314            Some(self.compressed[idx - imm_len].clone())
315        } else if idx < imm_len + comp_len + self.incremental.len() {
316            Some(self.incremental[idx - imm_len - comp_len].clone())
317        } else {
318            None
319        }
320    }
321}
322
323impl<B: ContextBackend> Message<RequestMessages> for AgentContext<B> {
324    type Reply = Vec<B::Message>;
325
326    async fn handle(
327        &mut self,
328        _msg: RequestMessages,
329        _ctx: &mut Context<Self, Self::Reply>,
330    ) -> Self::Reply {
331        self.immutable
332            .iter()
333            .chain(self.compressed.iter())
334            .chain(self.incremental.iter())
335            .cloned()
336            .collect()
337    }
338}
339
340impl<B: ContextBackend> Message<RequestImmutable> for AgentContext<B> {
341    type Reply = Vec<B::Message>;
342
343    async fn handle(
344        &mut self,
345        _msg: RequestImmutable,
346        _ctx: &mut Context<Self, Self::Reply>,
347    ) -> Self::Reply {
348        self.immutable.to_vec()
349    }
350}
351
352impl<B: ContextBackend> Message<RequestCompressed> for AgentContext<B> {
353    type Reply = Vec<B::Message>;
354
355    async fn handle(
356        &mut self,
357        _msg: RequestCompressed,
358        _ctx: &mut Context<Self, Self::Reply>,
359    ) -> Self::Reply {
360        self.compressed.clone()
361    }
362}
363
364impl<B: ContextBackend> Message<RequestIncremental> for AgentContext<B> {
365    type Reply = Vec<B::Message>;
366
367    async fn handle(
368        &mut self,
369        _msg: RequestIncremental,
370        _ctx: &mut Context<Self, Self::Reply>,
371    ) -> Self::Reply {
372        self.incremental.clone()
373    }
374}
375
376impl<B: ContextBackend> Message<RequestFindByRole> for AgentContext<B> {
377    type Reply = Vec<B::Message>;
378
379    async fn handle(
380        &mut self,
381        msg: RequestFindByRole,
382        _ctx: &mut Context<Self, Self::Reply>,
383    ) -> Self::Reply {
384        self.immutable
385            .iter()
386            .chain(self.compressed.iter())
387            .chain(self.incremental.iter())
388            .filter(|m| m.role() == msg.0)
389            .cloned()
390            .collect()
391    }
392}
393
394// ---------------------------------------------------------------------------
395// 对话消息
396// ---------------------------------------------------------------------------
397
398impl<B: ContextBackend> Message<RequestSend<B::Opts>> for AgentContext<B> {
399    type Reply = Result<B::Response, AgentError>;
400
401    async fn handle(
402        &mut self,
403        msg: RequestSend<B::Opts>,
404        _ctx: &mut Context<Self, Self::Reply>,
405    ) -> Self::Reply {
406        self.compress_if_full(&msg.opts).await?;
407        let scratch = msg.opts.as_ref().scratch.clone();
408        let mut all_messages: Vec<B::Message> = self
409            .immutable
410            .iter()
411            .chain(self.compressed.iter())
412            .chain(self.incremental.iter())
413            .cloned()
414            .collect();
415        if let Some(content) = scratch {
416            all_messages.push(self.backend.system_message(content));
417        }
418        let response = self.backend.send(&all_messages, &msg.opts).await?;
419        let raw_msgs = self
420            .backend
421            .extract_messages(std::slice::from_ref(&response))?;
422        let request_msgs = self.backend.to_request_messages(raw_msgs)?;
423        for msg in &request_msgs {
424            self.incremental.push(msg.clone());
425        }
426        Ok(response)
427    }
428}
429
430impl<B: ContextBackend + Clone> Message<RequestSendStream<B::Opts>> for AgentContext<B> {
431    type Reply = Result<AgentSendStream<B>, AgentError>;
432
433    async fn handle(
434        &mut self,
435        msg: RequestSendStream<B::Opts>,
436        _ctx: &mut Context<Self, Self::Reply>,
437    ) -> Self::Reply {
438        self.compress_if_full(&msg.opts).await?;
439        let scratch = msg.opts.as_ref().scratch.clone();
440        let mut all_messages: Vec<B::Message> = self
441            .immutable
442            .iter()
443            .chain(self.compressed.iter())
444            .chain(self.incremental.iter())
445            .cloned()
446            .collect();
447        if let Some(content) = scratch {
448            all_messages.push(self.backend.system_message(content));
449        }
450        let stream = self.backend.send_stream(all_messages, msg.opts);
451        Ok(AgentSendStream::new(stream))
452    }
453}
454
455// ---------------------------------------------------------------------------
456// 压缩消息
457// ---------------------------------------------------------------------------
458
459impl<B: ContextBackend> Message<RequestCompress<B::Opts>> for AgentContext<B> {
460    type Reply = Result<(), AgentError>;
461
462    async fn handle(
463        &mut self,
464        msg: RequestCompress<B::Opts>,
465        _ctx: &mut Context<Self, Self::Reply>,
466    ) -> Self::Reply {
467        match msg.strategy {
468            CompressStrategy::Summarize { keep, prompt } => {
469                let total = self.incremental.len();
470                if total > keep {
471                    let split = total - keep;
472                    let to_summarize: Vec<B::Message> = self.incremental.drain(..split).collect();
473                    if !to_summarize.is_empty() {
474                        let summary_prompt = prompt.unwrap_or_else(Self::default_summary_prompt);
475                        let mut summary_messages =
476                            vec![self.backend.system_message(summary_prompt)];
477                        summary_messages.append(&mut self.compressed);
478                        summary_messages.extend(to_summarize);
479                        let response = self.backend.send(&summary_messages, &msg.opts).await?;
480                        let raw_msgs = self
481                            .backend
482                            .extract_messages(std::slice::from_ref(&response))?;
483                        let request_msgs = self.backend.to_request_messages(raw_msgs)?;
484                        let summary: Vec<B::Message> = request_msgs
485                            .into_iter()
486                            .map(|msg| self.backend.to_system_message(msg))
487                            .collect();
488                        let kept: Vec<B::Message> = self.incremental.drain(..).collect();
489                        let (final_summary, final_kept) =
490                            self.notify_compressed_subscriber(summary, kept).await?;
491                        self.compressed = final_summary;
492                        self.incremental = final_kept;
493                    }
494                }
495                Ok(())
496            }
497        }
498    }
499}
500
501
502impl<B: ContextBackend> Message<RequestSubscribeCompressed<B::Message>> for AgentContext<B> {
503    type Reply = ();
504
505    async fn handle(
506        &mut self,
507        msg: RequestSubscribeCompressed<B::Message>,
508        _ctx: &mut Context<Self, Self::Reply>,
509    ) -> Self::Reply {
510        self.on_compressed = Some(msg.recipient);
511    }
512}
513
514impl<B: ContextBackend> Message<RequestUnsubscribeCompressed> for AgentContext<B> {
515    type Reply = ();
516
517    async fn handle(
518        &mut self,
519        _msg: RequestUnsubscribeCompressed,
520        _ctx: &mut Context<Self, Self::Reply>,
521    ) -> Self::Reply {
522        self.on_compressed = None;
523    }
524}
525
526// ---------------------------------------------------------------------------
527// 工具消息
528// ---------------------------------------------------------------------------
529
530impl<B: ContextBackend> Message<RequestEstimateTokens> for AgentContext<B> {
531    type Reply = usize;
532
533    async fn handle(
534        &mut self,
535        _msg: RequestEstimateTokens,
536        _ctx: &mut Context<Self, Self::Reply>,
537    ) -> Self::Reply {
538        let all: Vec<B::Message> = self
539            .immutable
540            .iter()
541            .chain(self.compressed.iter())
542            .chain(self.incremental.iter())
543            .cloned()
544            .collect();
545        self.backend.estimate_tokens(&all).await.unwrap_or(0)
546    }
547}
548
549impl<B: ContextBackend> Message<RequestExportAll> for AgentContext<B> {
550    type Reply = Result<String, AgentError>;
551
552    async fn handle(
553        &mut self,
554        _msg: RequestExportAll,
555        _ctx: &mut Context<Self, Self::Reply>,
556    ) -> Self::Reply {
557        let lines: Vec<String> = self
558            .immutable
559            .iter()
560            .chain(self.compressed.iter())
561            .chain(self.incremental.iter())
562            .map(|m| self.backend.message_to_json(m))
563            .collect::<Result<_, _>>()?;
564        Ok(lines.join("\n"))
565    }
566}
567
568impl<B: ContextBackend> Message<RequestExportIncremental> for AgentContext<B> {
569    type Reply = Result<String, AgentError>;
570
571    async fn handle(
572        &mut self,
573        _msg: RequestExportIncremental,
574        _ctx: &mut Context<Self, Self::Reply>,
575    ) -> Self::Reply {
576        let lines: Vec<String> = self
577            .incremental
578            .iter()
579            .map(|m| self.backend.message_to_json(m))
580            .collect::<Result<_, _>>()?;
581        Ok(lines.join("\n"))
582    }
583}
584
585impl<B: ContextBackend> Message<RequestImportIncremental> for AgentContext<B> {
586    type Reply = Result<(), AgentError>;
587
588    async fn handle(
589        &mut self,
590        msg: RequestImportIncremental,
591        _ctx: &mut Context<Self, Self::Reply>,
592    ) -> Self::Reply {
593        let mut messages = Vec::new();
594        for line in msg.json.lines() {
595            let line = line.trim();
596            if line.is_empty() {
597                continue;
598            }
599            messages.push(self.backend.message_from_json(line)?);
600        }
601        self.incremental.clear();
602        self.incremental = messages.clone();
603        Ok(())
604    }
605}