1use std::collections::HashSet;
6
7use kameo::prelude::*;
8
9use super::events::{
10 CompressStrategy, NotifyChange, NotifyCompressedForReply, RequestAppend, RequestClear,
11 RequestCompress, RequestCompressed, RequestEstimateTokens, RequestExtend, RequestFindByRole,
12 RequestFromJsonl, RequestGet, RequestImmutable, RequestIncremental, RequestInsert,
13 RequestIsEmpty, RequestLen, RequestMessages, RequestPop, RequestRemove, RequestRetain,
14 RequestSend, RequestSendStream, RequestSubscribeChange, RequestSubscribeCompressed,
15 RequestToJsonl, RequestUnsubscribeChange, RequestUnsubscribeCompressed, RequestUpdate,
16};
17use super::stream::AgentSendStream;
18use super::types::ContextBackend;
19use crate::error::AgentError;
20
21type CompressSubscriberReply<M> = (Vec<M>, Vec<M>);
22type CompressSubscriberRecipient<M> = ReplyRecipient<NotifyCompressedForReply<M>, CompressSubscriberReply<M>>;
23use crate::message::ContextMessage;
24use crate::readonly::ReadOnly;
25
26#[derive(Actor)]
47pub struct AgentContext<B: ContextBackend> {
48 backend: B,
49 immutable: ReadOnly<B::Message>,
50 compressed: Vec<B::Message>,
51 incremental: Vec<B::Message>,
52 subscribers: HashSet<Recipient<NotifyChange<B::Message>>>,
53 on_compressed: Option<CompressSubscriberRecipient<B::Message>>,
54}
55
56impl<B: ContextBackend> AgentContext<B> {
57 pub fn new(backend: B, immutable: Vec<B::Message>) -> Self {
62 Self {
63 backend,
64 immutable: ReadOnly::from(immutable),
65 compressed: Vec::new(),
66 incremental: Vec::new(),
67 subscribers: HashSet::new(),
68 on_compressed: None,
69 }
70 }
71
72 fn default_summary_prompt() -> String {
73 "请将以下对话历史压缩为简洁摘要,保留关键信息、决策和上下文。输出一条 system 消息。"
74 .to_string()
75 }
76
77 async fn notify_change(&mut self, event: NotifyChange<B::Message>) {
78 let mut dead = Vec::new();
79 for subscriber in &self.subscribers {
80 if let Err(_e) = subscriber.tell(event.clone()).send().await {
81 dead.push(subscriber.clone());
82 }
83 }
84 self.subscribers.retain(|s| !dead.contains(s));
85 }
86
87 async fn notify_compressed_subscriber(
89 &self,
90 summary: Vec<B::Message>,
91 kept: Vec<B::Message>,
92 ) -> Result<(Vec<B::Message>, Vec<B::Message>), AgentError> {
93 if let Some(subscriber) = &self.on_compressed {
94 subscriber
95 .ask(NotifyCompressedForReply {
96 summary,
97 kept,
98 })
99 .send()
100 .await
101 .map_err(|e| AgentError::Context(e.to_string()))
102 } else {
103 Ok((summary, kept))
104 }
105 }
106
107 async fn compress_if_full(&mut self, opts: &B::Opts) -> Result<(), AgentError> {
109 let common = opts.as_ref();
110 let all: Vec<B::Message> = self
111 .immutable
112 .iter()
113 .chain(self.compressed.iter())
114 .chain(self.incremental.iter())
115 .cloned()
116 .collect();
117 let tokens = self
118 .backend
119 .estimate_tokens(&all)
120 .await
121 .unwrap_or(usize::MAX);
122 if tokens < common.context_window {
123 return Ok(());
124 }
125 if !common.auto_compress {
126 return Err(AgentError::Context("上下文已满且未启用自动压缩".into()));
127 }
128 let total = self.incremental.len();
129 let keep = total / 2;
130 if total <= keep {
131 return Ok(());
132 }
133 let split = total - keep;
134 let to_summarize: Vec<B::Message> = self.incremental.drain(..split).collect();
135 if to_summarize.is_empty() {
136 return Ok(());
137 }
138 let mut summary_messages =
139 vec![self.backend.system_message(Self::default_summary_prompt())];
140 summary_messages.append(&mut self.compressed);
141 summary_messages.extend(to_summarize);
142 let response = self.backend.send(&summary_messages, opts).await?;
143 let raw_msgs = self
144 .backend
145 .extract_messages(std::slice::from_ref(&response))?;
146 let request_msgs = self.backend.to_request_messages(raw_msgs)?;
147 let summary: Vec<B::Message> = request_msgs
148 .into_iter()
149 .map(|msg| self.backend.to_system_message(msg))
150 .collect();
151 let kept: Vec<B::Message> = self.incremental.drain(..).collect();
152 let (final_summary, final_kept) =
153 self.notify_compressed_subscriber(summary, kept).await?;
154 self.compressed = final_summary;
155 self.incremental = final_kept;
156 Ok(())
157 }
158}
159
160impl<B: ContextBackend> Message<RequestAppend<B::Message>> for AgentContext<B> {
165 type Reply = ();
166
167 async fn handle(
168 &mut self,
169 msg: RequestAppend<B::Message>,
170 _ctx: &mut Context<Self, Self::Reply>,
171 ) -> Self::Reply {
172 self.incremental.push(msg.message);
173 if let Some(last) = self.incremental.last().cloned() {
174 self.notify_change(NotifyChange::Appended(last)).await;
175 }
176 }
177}
178
179impl<B: ContextBackend> Message<RequestExtend<B::Message>> for AgentContext<B> {
180 type Reply = ();
181
182 async fn handle(
183 &mut self,
184 msg: RequestExtend<B::Message>,
185 _ctx: &mut Context<Self, Self::Reply>,
186 ) -> Self::Reply {
187 for m in msg.messages {
188 self.incremental.push(m);
189 if let Some(last) = self.incremental.last().cloned() {
190 self.notify_change(NotifyChange::Appended(last)).await;
191 }
192 }
193 }
194}
195
196impl<B: ContextBackend> Message<RequestUpdate<B::Message>> for AgentContext<B> {
197 type Reply = Result<(), AgentError>;
198
199 async fn handle(
200 &mut self,
201 msg: RequestUpdate<B::Message>,
202 _ctx: &mut Context<Self, Self::Reply>,
203 ) -> Self::Reply {
204 if msg.index >= self.incremental.len() {
205 return Err(AgentError::Context("索引越界".into()));
206 }
207 let old = std::mem::replace(&mut self.incremental[msg.index], msg.message);
208 self.notify_change(NotifyChange::Updated {
209 index: msg.index,
210 old,
211 new: self.incremental[msg.index].clone(),
212 })
213 .await;
214 Ok(())
215 }
216}
217
218impl<B: ContextBackend> Message<RequestInsert<B::Message>> for AgentContext<B> {
219 type Reply = Result<(), AgentError>;
220
221 async fn handle(
222 &mut self,
223 msg: RequestInsert<B::Message>,
224 _ctx: &mut Context<Self, Self::Reply>,
225 ) -> Self::Reply {
226 if msg.index > self.incremental.len() {
227 return Err(AgentError::Context("索引越界".into()));
228 }
229 self.incremental.insert(msg.index, msg.message);
230 self.notify_change(NotifyChange::Inserted {
231 index: msg.index,
232 message: self.incremental[msg.index].clone(),
233 })
234 .await;
235 Ok(())
236 }
237}
238
239impl<B: ContextBackend> Message<RequestRemove> for AgentContext<B> {
240 type Reply = Result<(), AgentError>;
241
242 async fn handle(
243 &mut self,
244 msg: RequestRemove,
245 _ctx: &mut Context<Self, Self::Reply>,
246 ) -> Self::Reply {
247 if msg.index >= self.incremental.len() {
248 return Err(AgentError::Context("索引越界".into()));
249 }
250 let removed = self.incremental.remove(msg.index);
251 self.notify_change(NotifyChange::Removed {
252 index: msg.index,
253 message: removed,
254 })
255 .await;
256 Ok(())
257 }
258}
259
260impl<B: ContextBackend> Message<RequestPop> for AgentContext<B> {
261 type Reply = Option<B::Message>;
262
263 async fn handle(
264 &mut self,
265 _msg: RequestPop,
266 _ctx: &mut Context<Self, Self::Reply>,
267 ) -> Self::Reply {
268 let popped = self.incremental.pop();
269 if let Some(ref msg) = popped {
270 self.notify_change(NotifyChange::Popped(msg.clone())).await;
271 }
272 popped
273 }
274}
275
276impl<B: ContextBackend> Message<RequestRetain> for AgentContext<B> {
277 type Reply = ();
278
279 async fn handle(
280 &mut self,
281 msg: RequestRetain,
282 _ctx: &mut Context<Self, Self::Reply>,
283 ) -> Self::Reply {
284 let mut removed = Vec::new();
285 let role = msg.role;
286 self.incremental.retain(|m| {
287 if m.role() == role {
288 true
289 } else {
290 removed.push(m.clone());
291 false
292 }
293 });
294 self.notify_change(NotifyChange::Retained { role, removed })
295 .await;
296 }
297}
298
299impl<B: ContextBackend> Message<RequestClear> for AgentContext<B> {
300 type Reply = ();
301
302 async fn handle(
303 &mut self,
304 _msg: RequestClear,
305 _ctx: &mut Context<Self, Self::Reply>,
306 ) -> Self::Reply {
307 if !self.incremental.is_empty() {
308 let removed = std::mem::take(&mut self.incremental);
309 self.notify_change(NotifyChange::Cleared { removed }).await;
310 }
311 }
312}
313
314impl<B: ContextBackend> Message<RequestLen> for AgentContext<B> {
319 type Reply = usize;
320
321 async fn handle(
322 &mut self,
323 _msg: RequestLen,
324 _ctx: &mut Context<Self, Self::Reply>,
325 ) -> Self::Reply {
326 self.immutable.len() + self.compressed.len() + self.incremental.len()
327 }
328}
329
330impl<B: ContextBackend> Message<RequestIsEmpty> for AgentContext<B> {
331 type Reply = bool;
332
333 async fn handle(
334 &mut self,
335 _msg: RequestIsEmpty,
336 _ctx: &mut Context<Self, Self::Reply>,
337 ) -> Self::Reply {
338 self.immutable.is_empty() && self.compressed.is_empty() && self.incremental.is_empty()
339 }
340}
341
342impl<B: ContextBackend> Message<RequestGet> for AgentContext<B> {
343 type Reply = Option<B::Message>;
344
345 async fn handle(
346 &mut self,
347 msg: RequestGet,
348 _ctx: &mut Context<Self, Self::Reply>,
349 ) -> Self::Reply {
350 let idx = msg.0;
351 let imm_len = self.immutable.len();
352 let comp_len = self.compressed.len();
353 if idx < imm_len {
354 Some(self.immutable[idx].clone())
355 } else if idx < imm_len + comp_len {
356 Some(self.compressed[idx - imm_len].clone())
357 } else if idx < imm_len + comp_len + self.incremental.len() {
358 Some(self.incremental[idx - imm_len - comp_len].clone())
359 } else {
360 None
361 }
362 }
363}
364
365impl<B: ContextBackend> Message<RequestMessages> for AgentContext<B> {
366 type Reply = Vec<B::Message>;
367
368 async fn handle(
369 &mut self,
370 _msg: RequestMessages,
371 _ctx: &mut Context<Self, Self::Reply>,
372 ) -> Self::Reply {
373 self.immutable
374 .iter()
375 .chain(self.compressed.iter())
376 .chain(self.incremental.iter())
377 .cloned()
378 .collect()
379 }
380}
381
382impl<B: ContextBackend> Message<RequestImmutable> for AgentContext<B> {
383 type Reply = Vec<B::Message>;
384
385 async fn handle(
386 &mut self,
387 _msg: RequestImmutable,
388 _ctx: &mut Context<Self, Self::Reply>,
389 ) -> Self::Reply {
390 self.immutable.to_vec()
391 }
392}
393
394impl<B: ContextBackend> Message<RequestCompressed> for AgentContext<B> {
395 type Reply = Vec<B::Message>;
396
397 async fn handle(
398 &mut self,
399 _msg: RequestCompressed,
400 _ctx: &mut Context<Self, Self::Reply>,
401 ) -> Self::Reply {
402 self.compressed.clone()
403 }
404}
405
406impl<B: ContextBackend> Message<RequestIncremental> for AgentContext<B> {
407 type Reply = Vec<B::Message>;
408
409 async fn handle(
410 &mut self,
411 _msg: RequestIncremental,
412 _ctx: &mut Context<Self, Self::Reply>,
413 ) -> Self::Reply {
414 self.incremental.clone()
415 }
416}
417
418impl<B: ContextBackend> Message<RequestFindByRole> for AgentContext<B> {
419 type Reply = Vec<B::Message>;
420
421 async fn handle(
422 &mut self,
423 msg: RequestFindByRole,
424 _ctx: &mut Context<Self, Self::Reply>,
425 ) -> Self::Reply {
426 self.immutable
427 .iter()
428 .chain(self.compressed.iter())
429 .chain(self.incremental.iter())
430 .filter(|m| m.role() == msg.0)
431 .cloned()
432 .collect()
433 }
434}
435
436impl<B: ContextBackend> Message<RequestSend<B::Opts>> for AgentContext<B> {
441 type Reply = Result<B::Response, AgentError>;
442
443 async fn handle(
444 &mut self,
445 msg: RequestSend<B::Opts>,
446 _ctx: &mut Context<Self, Self::Reply>,
447 ) -> Self::Reply {
448 self.compress_if_full(&msg.opts).await?;
449 let scratch = msg.opts.as_ref().scratch.clone();
450 let mut all_messages: Vec<B::Message> = self
451 .immutable
452 .iter()
453 .chain(self.compressed.iter())
454 .chain(self.incremental.iter())
455 .cloned()
456 .collect();
457 if let Some(content) = scratch {
458 all_messages.push(self.backend.system_message(content));
459 }
460 let response = self.backend.send(&all_messages, &msg.opts).await?;
461 let raw_msgs = self
462 .backend
463 .extract_messages(std::slice::from_ref(&response))?;
464 let request_msgs = self.backend.to_request_messages(raw_msgs)?;
465 for msg in &request_msgs {
466 self.incremental.push(msg.clone());
467 self.notify_change(NotifyChange::Appended(msg.clone()))
468 .await;
469 }
470 Ok(response)
471 }
472}
473
474impl<B: ContextBackend + Clone> Message<RequestSendStream<B::Opts>> for AgentContext<B> {
475 type Reply = Result<AgentSendStream<B>, AgentError>;
476
477 async fn handle(
478 &mut self,
479 msg: RequestSendStream<B::Opts>,
480 _ctx: &mut Context<Self, Self::Reply>,
481 ) -> Self::Reply {
482 self.compress_if_full(&msg.opts).await?;
483 let scratch = msg.opts.as_ref().scratch.clone();
484 let mut all_messages: Vec<B::Message> = self
485 .immutable
486 .iter()
487 .chain(self.compressed.iter())
488 .chain(self.incremental.iter())
489 .cloned()
490 .collect();
491 if let Some(content) = scratch {
492 all_messages.push(self.backend.system_message(content));
493 }
494 let stream = self.backend.send_stream(all_messages, msg.opts);
495 Ok(AgentSendStream::new(stream))
496 }
497}
498
499impl<B: ContextBackend> Message<RequestCompress<B::Opts>> for AgentContext<B> {
504 type Reply = Result<(), AgentError>;
505
506 async fn handle(
507 &mut self,
508 msg: RequestCompress<B::Opts>,
509 _ctx: &mut Context<Self, Self::Reply>,
510 ) -> Self::Reply {
511 match msg.strategy {
512 CompressStrategy::Summarize { keep, prompt } => {
513 let total = self.incremental.len();
514 if total > keep {
515 let split = total - keep;
516 let to_summarize: Vec<B::Message> = self.incremental.drain(..split).collect();
517 if !to_summarize.is_empty() {
518 let summary_prompt = prompt.unwrap_or_else(Self::default_summary_prompt);
519 let mut summary_messages =
520 vec![self.backend.system_message(summary_prompt)];
521 summary_messages.append(&mut self.compressed);
522 summary_messages.extend(to_summarize);
523 let response = self.backend.send(&summary_messages, &msg.opts).await?;
524 let raw_msgs = self
525 .backend
526 .extract_messages(std::slice::from_ref(&response))?;
527 let request_msgs = self.backend.to_request_messages(raw_msgs)?;
528 let summary: Vec<B::Message> = request_msgs
529 .into_iter()
530 .map(|msg| self.backend.to_system_message(msg))
531 .collect();
532 let kept: Vec<B::Message> = self.incremental.drain(..).collect();
533 let (final_summary, final_kept) =
534 self.notify_compressed_subscriber(summary, kept).await?;
535 self.compressed = final_summary;
536 self.incremental = final_kept;
537 }
538 }
539 Ok(())
540 }
541 }
542 }
543}
544
545impl<B: ContextBackend> Message<RequestSubscribeChange<B::Message>> for AgentContext<B> {
550 type Reply = ();
551
552 async fn handle(
553 &mut self,
554 msg: RequestSubscribeChange<B::Message>,
555 _ctx: &mut Context<Self, Self::Reply>,
556 ) -> Self::Reply {
557 self.subscribers.insert(msg.recipient);
558 }
559}
560
561impl<B: ContextBackend> Message<RequestUnsubscribeChange<B::Message>> for AgentContext<B> {
562 type Reply = bool;
563
564 async fn handle(
565 &mut self,
566 msg: RequestUnsubscribeChange<B::Message>,
567 _ctx: &mut Context<Self, Self::Reply>,
568 ) -> Self::Reply {
569 self.subscribers.remove(&msg.recipient)
570 }
571}
572
573impl<B: ContextBackend> Message<RequestSubscribeCompressed<B::Message>> for AgentContext<B> {
574 type Reply = ();
575
576 async fn handle(
577 &mut self,
578 msg: RequestSubscribeCompressed<B::Message>,
579 _ctx: &mut Context<Self, Self::Reply>,
580 ) -> Self::Reply {
581 self.on_compressed = Some(msg.recipient);
582 }
583}
584
585impl<B: ContextBackend> Message<RequestUnsubscribeCompressed> for AgentContext<B> {
586 type Reply = ();
587
588 async fn handle(
589 &mut self,
590 _msg: RequestUnsubscribeCompressed,
591 _ctx: &mut Context<Self, Self::Reply>,
592 ) -> Self::Reply {
593 self.on_compressed = None;
594 }
595}
596
597impl<B: ContextBackend> Message<RequestEstimateTokens> for AgentContext<B> {
602 type Reply = usize;
603
604 async fn handle(
605 &mut self,
606 _msg: RequestEstimateTokens,
607 _ctx: &mut Context<Self, Self::Reply>,
608 ) -> Self::Reply {
609 let all: Vec<B::Message> = self
610 .immutable
611 .iter()
612 .chain(self.compressed.iter())
613 .chain(self.incremental.iter())
614 .cloned()
615 .collect();
616 self.backend.estimate_tokens(&all).await.unwrap_or(0)
617 }
618}
619
620impl<B: ContextBackend> Message<RequestToJsonl> for AgentContext<B> {
621 type Reply = Result<String, AgentError>;
622
623 async fn handle(
624 &mut self,
625 _msg: RequestToJsonl,
626 _ctx: &mut Context<Self, Self::Reply>,
627 ) -> Self::Reply {
628 let lines: Vec<String> = self
629 .immutable
630 .iter()
631 .chain(self.compressed.iter())
632 .chain(self.incremental.iter())
633 .map(|m| self.backend.message_to_jsonl(m))
634 .collect::<Result<_, _>>()?;
635 Ok(lines.join("\n"))
636 }
637}
638
639impl<B: ContextBackend> Message<RequestFromJsonl> for AgentContext<B> {
640 type Reply = Result<(), AgentError>;
641
642 async fn handle(
643 &mut self,
644 msg: RequestFromJsonl,
645 _ctx: &mut Context<Self, Self::Reply>,
646 ) -> Self::Reply {
647 for line in msg.jsonl.lines() {
648 let line = line.trim();
649 if line.is_empty() {
650 continue;
651 }
652 let message: B::Message = self.backend.message_from_jsonl(line)?;
653 self.incremental.push(message.clone());
654 self.notify_change(NotifyChange::Appended(message)).await;
655 }
656 Ok(())
657 }
658}