Skip to main content

agent_context/context/
actor.rs

1//! [`AgentContext`] Actor 及所有 [`Message`] 实现。
2//!
3//! 管理三区消息模型的状态,提供消息增删改查、模型对话、上下文压缩等操作。
4
5use std::collections::HashSet;
6
7use kameo::prelude::*;
8
9use super::events::{
10    CompressStrategy, NotifyChange, NotifyCompressedForReply, RequestAppend, RequestClear,
11    RequestCompress, RequestCompressed, RequestEstimateTokens, RequestExtend, RequestFindByRole,
12    RequestFromJsonl, RequestGet, RequestImmutable, RequestIncremental, RequestInsert,
13    RequestIsEmpty, RequestLen, RequestMessages, RequestPop, RequestRemove, RequestRetain,
14    RequestSend, RequestSendStream, RequestSubscribeChange, RequestSubscribeCompressed,
15    RequestToJsonl, RequestUnsubscribeChange, RequestUnsubscribeCompressed, RequestUpdate,
16};
17use super::stream::AgentSendStream;
18use super::types::ContextBackend;
19use crate::error::AgentError;
20
21type CompressSubscriberReply<M> = (Vec<M>, Vec<M>);
22type CompressSubscriberRecipient<M> = ReplyRecipient<NotifyCompressedForReply<M>, CompressSubscriberReply<M>>;
23use crate::message::ContextMessage;
24use crate::readonly::ReadOnly;
25
26// ---------------------------------------------------------------------------
27// AgentContext Actor
28// ---------------------------------------------------------------------------
29
30/// LLM 对话上下文管理器,kameo Actor。
31///
32/// 管理三区 + Scratch 消息模型(immutable → compressed → incremental → scratch),提供:
33/// - 消息增删改查([`RequestAppend`]、[`RequestUpdate`]、[`RequestRemove`] 等)
34/// - 对话发送([`RequestSend`]、[`RequestSendStream`]),支持通过 [`super::CommonOpts::scratch`] 追加临时元数据
35/// - 上下文压缩([`RequestCompress`])
36/// - Token 估算([`RequestEstimateTokens`])
37/// - 订阅管理([`RequestSubscribeChange`]、[`RequestUnsubscribeChange`]、
38///   [`RequestSubscribeCompressed`]、[`RequestUnsubscribeCompressed`])
39///
40/// ## 构造
41///
42/// ```ignore
43/// let actor = AgentContext::spawn(AgentContext::new(backend, vec![]));
44/// actor.ask(RequestSubscribeChange { recipient: app_ref.recipient() }).await?;
45/// ```
46#[derive(Actor)]
47pub struct AgentContext<B: ContextBackend> {
48    backend: B,
49    immutable: ReadOnly<B::Message>,
50    compressed: Vec<B::Message>,
51    incremental: Vec<B::Message>,
52    subscribers: HashSet<Recipient<NotifyChange<B::Message>>>,
53    on_compressed: Option<CompressSubscriberRecipient<B::Message>>,
54}
55
56impl<B: ContextBackend> AgentContext<B> {
57    /// 创建新的上下文管理器。
58    ///
59    /// - `backend`: 实现了 [`ContextBackend`] 的 LLM 后端实例
60    /// - `immutable`: 初始不可变消息(系统提示词等),放入 immutable 区
61    pub fn new(backend: B, immutable: Vec<B::Message>) -> Self {
62        Self {
63            backend,
64            immutable: ReadOnly::from(immutable),
65            compressed: Vec::new(),
66            incremental: Vec::new(),
67            subscribers: HashSet::new(),
68            on_compressed: None,
69        }
70    }
71
72    fn default_summary_prompt() -> String {
73        "请将以下对话历史压缩为简洁摘要,保留关键信息、决策和上下文。输出一条 system 消息。"
74            .to_string()
75    }
76
77    async fn notify_change(&mut self, event: NotifyChange<B::Message>) {
78        let mut dead = Vec::new();
79        for subscriber in &self.subscribers {
80            if let Err(_e) = subscriber.tell(event.clone()).send().await {
81                dead.push(subscriber.clone());
82            }
83        }
84        self.subscribers.retain(|s| !dead.contains(s));
85    }
86
87    /// 将摘要和保留消息交给压缩结果订阅者处理,若无订阅者则原样返回。
88    async fn notify_compressed_subscriber(
89        &self,
90        summary: Vec<B::Message>,
91        kept: Vec<B::Message>,
92    ) -> Result<(Vec<B::Message>, Vec<B::Message>), AgentError> {
93        if let Some(subscriber) = &self.on_compressed {
94            subscriber
95                .ask(NotifyCompressedForReply {
96                    summary,
97                    kept,
98                })
99                .send()
100                .await
101                .map_err(|e| AgentError::Context(e.to_string()))
102        } else {
103            Ok((summary, kept))
104        }
105    }
106
107    /// 检查上下文是否已满,如果已满且 `auto_compress` 则自动压缩。
108    async fn compress_if_full(&mut self, opts: &B::Opts) -> Result<(), AgentError> {
109        let common = opts.as_ref();
110        let all: Vec<B::Message> = self
111            .immutable
112            .iter()
113            .chain(self.compressed.iter())
114            .chain(self.incremental.iter())
115            .cloned()
116            .collect();
117        let tokens = self
118            .backend
119            .estimate_tokens(&all)
120            .await
121            .unwrap_or(usize::MAX);
122        if tokens < common.context_window {
123            return Ok(());
124        }
125        if !common.auto_compress {
126            return Err(AgentError::Context("上下文已满且未启用自动压缩".into()));
127        }
128        let total = self.incremental.len();
129        let keep = total / 2;
130        if total <= keep {
131            return Ok(());
132        }
133        let split = total - keep;
134        let to_summarize: Vec<B::Message> = self.incremental.drain(..split).collect();
135        if to_summarize.is_empty() {
136            return Ok(());
137        }
138        let mut summary_messages =
139            vec![self.backend.system_message(Self::default_summary_prompt())];
140        summary_messages.append(&mut self.compressed);
141        summary_messages.extend(to_summarize);
142        let response = self.backend.send(&summary_messages, opts).await?;
143        let raw_msgs = self
144            .backend
145            .extract_messages(std::slice::from_ref(&response))?;
146        let request_msgs = self.backend.to_request_messages(raw_msgs)?;
147        let summary: Vec<B::Message> = request_msgs
148            .into_iter()
149            .map(|msg| self.backend.to_system_message(msg))
150            .collect();
151        let kept: Vec<B::Message> = self.incremental.drain(..).collect();
152        let (final_summary, final_kept) =
153            self.notify_compressed_subscriber(summary, kept).await?;
154        self.compressed = final_summary;
155        self.incremental = final_kept;
156        Ok(())
157    }
158}
159
160// ---------------------------------------------------------------------------
161// 变更消息
162// ---------------------------------------------------------------------------
163
164impl<B: ContextBackend> Message<RequestAppend<B::Message>> for AgentContext<B> {
165    type Reply = ();
166
167    async fn handle(
168        &mut self,
169        msg: RequestAppend<B::Message>,
170        _ctx: &mut Context<Self, Self::Reply>,
171    ) -> Self::Reply {
172        self.incremental.push(msg.message);
173        if let Some(last) = self.incremental.last().cloned() {
174            self.notify_change(NotifyChange::Appended(last)).await;
175        }
176    }
177}
178
179impl<B: ContextBackend> Message<RequestExtend<B::Message>> for AgentContext<B> {
180    type Reply = ();
181
182    async fn handle(
183        &mut self,
184        msg: RequestExtend<B::Message>,
185        _ctx: &mut Context<Self, Self::Reply>,
186    ) -> Self::Reply {
187        for m in msg.messages {
188            self.incremental.push(m);
189            if let Some(last) = self.incremental.last().cloned() {
190                self.notify_change(NotifyChange::Appended(last)).await;
191            }
192        }
193    }
194}
195
196impl<B: ContextBackend> Message<RequestUpdate<B::Message>> for AgentContext<B> {
197    type Reply = Result<(), AgentError>;
198
199    async fn handle(
200        &mut self,
201        msg: RequestUpdate<B::Message>,
202        _ctx: &mut Context<Self, Self::Reply>,
203    ) -> Self::Reply {
204        if msg.index >= self.incremental.len() {
205            return Err(AgentError::Context("索引越界".into()));
206        }
207        let old = std::mem::replace(&mut self.incremental[msg.index], msg.message);
208        self.notify_change(NotifyChange::Updated {
209            index: msg.index,
210            old,
211            new: self.incremental[msg.index].clone(),
212        })
213        .await;
214        Ok(())
215    }
216}
217
218impl<B: ContextBackend> Message<RequestInsert<B::Message>> for AgentContext<B> {
219    type Reply = Result<(), AgentError>;
220
221    async fn handle(
222        &mut self,
223        msg: RequestInsert<B::Message>,
224        _ctx: &mut Context<Self, Self::Reply>,
225    ) -> Self::Reply {
226        if msg.index > self.incremental.len() {
227            return Err(AgentError::Context("索引越界".into()));
228        }
229        self.incremental.insert(msg.index, msg.message);
230        self.notify_change(NotifyChange::Inserted {
231            index: msg.index,
232            message: self.incremental[msg.index].clone(),
233        })
234        .await;
235        Ok(())
236    }
237}
238
239impl<B: ContextBackend> Message<RequestRemove> for AgentContext<B> {
240    type Reply = Result<(), AgentError>;
241
242    async fn handle(
243        &mut self,
244        msg: RequestRemove,
245        _ctx: &mut Context<Self, Self::Reply>,
246    ) -> Self::Reply {
247        if msg.index >= self.incremental.len() {
248            return Err(AgentError::Context("索引越界".into()));
249        }
250        let removed = self.incremental.remove(msg.index);
251        self.notify_change(NotifyChange::Removed {
252            index: msg.index,
253            message: removed,
254        })
255        .await;
256        Ok(())
257    }
258}
259
260impl<B: ContextBackend> Message<RequestPop> for AgentContext<B> {
261    type Reply = Option<B::Message>;
262
263    async fn handle(
264        &mut self,
265        _msg: RequestPop,
266        _ctx: &mut Context<Self, Self::Reply>,
267    ) -> Self::Reply {
268        let popped = self.incremental.pop();
269        if let Some(ref msg) = popped {
270            self.notify_change(NotifyChange::Popped(msg.clone())).await;
271        }
272        popped
273    }
274}
275
276impl<B: ContextBackend> Message<RequestRetain> for AgentContext<B> {
277    type Reply = ();
278
279    async fn handle(
280        &mut self,
281        msg: RequestRetain,
282        _ctx: &mut Context<Self, Self::Reply>,
283    ) -> Self::Reply {
284        let mut removed = Vec::new();
285        let role = msg.role;
286        self.incremental.retain(|m| {
287            if m.role() == role {
288                true
289            } else {
290                removed.push(m.clone());
291                false
292            }
293        });
294        self.notify_change(NotifyChange::Retained { role, removed })
295            .await;
296    }
297}
298
299impl<B: ContextBackend> Message<RequestClear> for AgentContext<B> {
300    type Reply = ();
301
302    async fn handle(
303        &mut self,
304        _msg: RequestClear,
305        _ctx: &mut Context<Self, Self::Reply>,
306    ) -> Self::Reply {
307        if !self.incremental.is_empty() {
308            let removed = std::mem::take(&mut self.incremental);
309            self.notify_change(NotifyChange::Cleared { removed }).await;
310        }
311    }
312}
313
314// ---------------------------------------------------------------------------
315// 查询消息
316// ---------------------------------------------------------------------------
317
318impl<B: ContextBackend> Message<RequestLen> for AgentContext<B> {
319    type Reply = usize;
320
321    async fn handle(
322        &mut self,
323        _msg: RequestLen,
324        _ctx: &mut Context<Self, Self::Reply>,
325    ) -> Self::Reply {
326        self.immutable.len() + self.compressed.len() + self.incremental.len()
327    }
328}
329
330impl<B: ContextBackend> Message<RequestIsEmpty> for AgentContext<B> {
331    type Reply = bool;
332
333    async fn handle(
334        &mut self,
335        _msg: RequestIsEmpty,
336        _ctx: &mut Context<Self, Self::Reply>,
337    ) -> Self::Reply {
338        self.immutable.is_empty() && self.compressed.is_empty() && self.incremental.is_empty()
339    }
340}
341
342impl<B: ContextBackend> Message<RequestGet> for AgentContext<B> {
343    type Reply = Option<B::Message>;
344
345    async fn handle(
346        &mut self,
347        msg: RequestGet,
348        _ctx: &mut Context<Self, Self::Reply>,
349    ) -> Self::Reply {
350        let idx = msg.0;
351        let imm_len = self.immutable.len();
352        let comp_len = self.compressed.len();
353        if idx < imm_len {
354            Some(self.immutable[idx].clone())
355        } else if idx < imm_len + comp_len {
356            Some(self.compressed[idx - imm_len].clone())
357        } else if idx < imm_len + comp_len + self.incremental.len() {
358            Some(self.incremental[idx - imm_len - comp_len].clone())
359        } else {
360            None
361        }
362    }
363}
364
365impl<B: ContextBackend> Message<RequestMessages> for AgentContext<B> {
366    type Reply = Vec<B::Message>;
367
368    async fn handle(
369        &mut self,
370        _msg: RequestMessages,
371        _ctx: &mut Context<Self, Self::Reply>,
372    ) -> Self::Reply {
373        self.immutable
374            .iter()
375            .chain(self.compressed.iter())
376            .chain(self.incremental.iter())
377            .cloned()
378            .collect()
379    }
380}
381
382impl<B: ContextBackend> Message<RequestImmutable> for AgentContext<B> {
383    type Reply = Vec<B::Message>;
384
385    async fn handle(
386        &mut self,
387        _msg: RequestImmutable,
388        _ctx: &mut Context<Self, Self::Reply>,
389    ) -> Self::Reply {
390        self.immutable.to_vec()
391    }
392}
393
394impl<B: ContextBackend> Message<RequestCompressed> for AgentContext<B> {
395    type Reply = Vec<B::Message>;
396
397    async fn handle(
398        &mut self,
399        _msg: RequestCompressed,
400        _ctx: &mut Context<Self, Self::Reply>,
401    ) -> Self::Reply {
402        self.compressed.clone()
403    }
404}
405
406impl<B: ContextBackend> Message<RequestIncremental> for AgentContext<B> {
407    type Reply = Vec<B::Message>;
408
409    async fn handle(
410        &mut self,
411        _msg: RequestIncremental,
412        _ctx: &mut Context<Self, Self::Reply>,
413    ) -> Self::Reply {
414        self.incremental.clone()
415    }
416}
417
418impl<B: ContextBackend> Message<RequestFindByRole> for AgentContext<B> {
419    type Reply = Vec<B::Message>;
420
421    async fn handle(
422        &mut self,
423        msg: RequestFindByRole,
424        _ctx: &mut Context<Self, Self::Reply>,
425    ) -> Self::Reply {
426        self.immutable
427            .iter()
428            .chain(self.compressed.iter())
429            .chain(self.incremental.iter())
430            .filter(|m| m.role() == msg.0)
431            .cloned()
432            .collect()
433    }
434}
435
436// ---------------------------------------------------------------------------
437// 对话消息
438// ---------------------------------------------------------------------------
439
440impl<B: ContextBackend> Message<RequestSend<B::Opts>> for AgentContext<B> {
441    type Reply = Result<B::Response, AgentError>;
442
443    async fn handle(
444        &mut self,
445        msg: RequestSend<B::Opts>,
446        _ctx: &mut Context<Self, Self::Reply>,
447    ) -> Self::Reply {
448        self.compress_if_full(&msg.opts).await?;
449        let scratch = msg.opts.as_ref().scratch.clone();
450        let mut all_messages: Vec<B::Message> = self
451            .immutable
452            .iter()
453            .chain(self.compressed.iter())
454            .chain(self.incremental.iter())
455            .cloned()
456            .collect();
457        if let Some(content) = scratch {
458            all_messages.push(self.backend.system_message(content));
459        }
460        let response = self.backend.send(&all_messages, &msg.opts).await?;
461        let raw_msgs = self
462            .backend
463            .extract_messages(std::slice::from_ref(&response))?;
464        let request_msgs = self.backend.to_request_messages(raw_msgs)?;
465        for msg in &request_msgs {
466            self.incremental.push(msg.clone());
467            self.notify_change(NotifyChange::Appended(msg.clone()))
468                .await;
469        }
470        Ok(response)
471    }
472}
473
474impl<B: ContextBackend + Clone> Message<RequestSendStream<B::Opts>> for AgentContext<B> {
475    type Reply = Result<AgentSendStream<B>, AgentError>;
476
477    async fn handle(
478        &mut self,
479        msg: RequestSendStream<B::Opts>,
480        _ctx: &mut Context<Self, Self::Reply>,
481    ) -> Self::Reply {
482        self.compress_if_full(&msg.opts).await?;
483        let scratch = msg.opts.as_ref().scratch.clone();
484        let mut all_messages: Vec<B::Message> = self
485            .immutable
486            .iter()
487            .chain(self.compressed.iter())
488            .chain(self.incremental.iter())
489            .cloned()
490            .collect();
491        if let Some(content) = scratch {
492            all_messages.push(self.backend.system_message(content));
493        }
494        let stream = self.backend.send_stream(all_messages, msg.opts);
495        Ok(AgentSendStream::new(stream))
496    }
497}
498
499// ---------------------------------------------------------------------------
500// 压缩消息
501// ---------------------------------------------------------------------------
502
503impl<B: ContextBackend> Message<RequestCompress<B::Opts>> for AgentContext<B> {
504    type Reply = Result<(), AgentError>;
505
506    async fn handle(
507        &mut self,
508        msg: RequestCompress<B::Opts>,
509        _ctx: &mut Context<Self, Self::Reply>,
510    ) -> Self::Reply {
511        match msg.strategy {
512            CompressStrategy::Summarize { keep, prompt } => {
513                let total = self.incremental.len();
514                if total > keep {
515                    let split = total - keep;
516                    let to_summarize: Vec<B::Message> = self.incremental.drain(..split).collect();
517                    if !to_summarize.is_empty() {
518                        let summary_prompt = prompt.unwrap_or_else(Self::default_summary_prompt);
519                        let mut summary_messages =
520                            vec![self.backend.system_message(summary_prompt)];
521                        summary_messages.append(&mut self.compressed);
522                        summary_messages.extend(to_summarize);
523                        let response = self.backend.send(&summary_messages, &msg.opts).await?;
524                        let raw_msgs = self
525                            .backend
526                            .extract_messages(std::slice::from_ref(&response))?;
527                        let request_msgs = self.backend.to_request_messages(raw_msgs)?;
528                        let summary: Vec<B::Message> = request_msgs
529                            .into_iter()
530                            .map(|msg| self.backend.to_system_message(msg))
531                            .collect();
532                        let kept: Vec<B::Message> = self.incremental.drain(..).collect();
533                        let (final_summary, final_kept) =
534                            self.notify_compressed_subscriber(summary, kept).await?;
535                        self.compressed = final_summary;
536                        self.incremental = final_kept;
537                    }
538                }
539                Ok(())
540            }
541        }
542    }
543}
544
545// ---------------------------------------------------------------------------
546// 订阅管理消息
547// ---------------------------------------------------------------------------
548
549impl<B: ContextBackend> Message<RequestSubscribeChange<B::Message>> for AgentContext<B> {
550    type Reply = ();
551
552    async fn handle(
553        &mut self,
554        msg: RequestSubscribeChange<B::Message>,
555        _ctx: &mut Context<Self, Self::Reply>,
556    ) -> Self::Reply {
557        self.subscribers.insert(msg.recipient);
558    }
559}
560
561impl<B: ContextBackend> Message<RequestUnsubscribeChange<B::Message>> for AgentContext<B> {
562    type Reply = bool;
563
564    async fn handle(
565        &mut self,
566        msg: RequestUnsubscribeChange<B::Message>,
567        _ctx: &mut Context<Self, Self::Reply>,
568    ) -> Self::Reply {
569        self.subscribers.remove(&msg.recipient)
570    }
571}
572
573impl<B: ContextBackend> Message<RequestSubscribeCompressed<B::Message>> for AgentContext<B> {
574    type Reply = ();
575
576    async fn handle(
577        &mut self,
578        msg: RequestSubscribeCompressed<B::Message>,
579        _ctx: &mut Context<Self, Self::Reply>,
580    ) -> Self::Reply {
581        self.on_compressed = Some(msg.recipient);
582    }
583}
584
585impl<B: ContextBackend> Message<RequestUnsubscribeCompressed> for AgentContext<B> {
586    type Reply = ();
587
588    async fn handle(
589        &mut self,
590        _msg: RequestUnsubscribeCompressed,
591        _ctx: &mut Context<Self, Self::Reply>,
592    ) -> Self::Reply {
593        self.on_compressed = None;
594    }
595}
596
597// ---------------------------------------------------------------------------
598// 工具消息
599// ---------------------------------------------------------------------------
600
601impl<B: ContextBackend> Message<RequestEstimateTokens> for AgentContext<B> {
602    type Reply = usize;
603
604    async fn handle(
605        &mut self,
606        _msg: RequestEstimateTokens,
607        _ctx: &mut Context<Self, Self::Reply>,
608    ) -> Self::Reply {
609        let all: Vec<B::Message> = self
610            .immutable
611            .iter()
612            .chain(self.compressed.iter())
613            .chain(self.incremental.iter())
614            .cloned()
615            .collect();
616        self.backend.estimate_tokens(&all).await.unwrap_or(0)
617    }
618}
619
620impl<B: ContextBackend> Message<RequestToJsonl> for AgentContext<B> {
621    type Reply = Result<String, AgentError>;
622
623    async fn handle(
624        &mut self,
625        _msg: RequestToJsonl,
626        _ctx: &mut Context<Self, Self::Reply>,
627    ) -> Self::Reply {
628        let lines: Vec<String> = self
629            .immutable
630            .iter()
631            .chain(self.compressed.iter())
632            .chain(self.incremental.iter())
633            .map(|m| self.backend.message_to_jsonl(m))
634            .collect::<Result<_, _>>()?;
635        Ok(lines.join("\n"))
636    }
637}
638
639impl<B: ContextBackend> Message<RequestFromJsonl> for AgentContext<B> {
640    type Reply = Result<(), AgentError>;
641
642    async fn handle(
643        &mut self,
644        msg: RequestFromJsonl,
645        _ctx: &mut Context<Self, Self::Reply>,
646    ) -> Self::Reply {
647        for line in msg.jsonl.lines() {
648            let line = line.trim();
649            if line.is_empty() {
650                continue;
651            }
652            let message: B::Message = self.backend.message_from_jsonl(line)?;
653            self.incremental.push(message.clone());
654            self.notify_change(NotifyChange::Appended(message)).await;
655        }
656        Ok(())
657    }
658}