1use std::sync::atomic::{AtomicBool, Ordering};
2use std::sync::Arc;
3
4use log;
5use tokio::sync::mpsc;
6
7use crate::backend::{GenerationResult, InferenceParams, LlmBackend};
8use crate::context::{plan_prune, prepare_context, ContextConfig, PruneStrategy};
9use crate::error::{CoreError, CoreResult};
10use crate::events::AgentEvent;
11use crate::messages::{Message, Role, ToolCall};
12use crate::template::{ChatMLTemplate, ChatTemplate};
13use crate::tools::{parse_tool_calls, ToolSchema};
14#[cfg(feature = "tools")]
15use crate::{
16 messages::ToolResult,
17 tools::{Tool, ToolOutput, ToolUpdateCallback},
18};
19
/// Prefix of the pinned user message that carries the rolling conversation
/// summary. It is also how an earlier summary is recognised (together with
/// `pinned`) when the history is summarized again.
const SUMMARY_MARKER: &str = "[Summary of earlier conversation]";

/// `max_tokens` used for the summarization request; it replaces the configured
/// inference `max_tokens` for that one generation only.
const SUMMARY_MAX_TOKENS: u32 = 320;
26
/// Settings that control how an [`Agent`] builds prompts and runs turns.
#[derive(Debug, Clone)]
pub struct AgentConfig {
    /// System prompt handed to context preparation on every turn.
    pub system_prompt: String,
    /// Generation parameters forwarded to the backend's `generate` call
    /// (also the base for the summarization request).
    pub inference_params: InferenceParams,
    /// Context-window budget and pruning strategy used when building prompts.
    pub context_config: ContextConfig,
    /// Upper bound on the number of generation turns `Agent::prompt` runs for
    /// one user message (each round of tool calls costs a turn) before it
    /// stops with a warning.
    pub max_tool_iterations: usize,
}
41
42impl Default for AgentConfig {
43 fn default() -> Self {
44 Self {
45 system_prompt: "You are a helpful assistant.".to_string(),
46 inference_params: InferenceParams::default(),
47 context_config: ContextConfig::default(),
48 max_tool_iterations: 8,
49 }
50 }
51}
52
/// A conversational agent.
///
/// It owns the message history, the optional tool set, and the chat template,
/// and drives generate → tool-call turns against an [`LlmBackend`].
pub struct Agent {
    /// Prompting, inference, and context settings.
    config: AgentConfig,
    /// Conversation history: user, assistant, tool-result, and summary messages.
    messages: Vec<Message>,
    /// Tools the model may call; only present with the `tools` feature.
    #[cfg(feature = "tools")]
    tools: Vec<Box<dyn Tool>>,
    /// Chat template used to render prompts.
    template: Arc<dyn ChatTemplate>,
    /// Cancellation flag, passed to the backend and shared with callers via
    /// `abort_flag()`.
    abort: Arc<AtomicBool>,
    /// Number of the last issued message id; `next_id` increments it and
    /// yields `msg-{counter}`.
    msg_counter: u64,
}
89
90impl Agent {
91 pub fn new(config: AgentConfig) -> Self {
93 Self::with_template(config, Arc::new(ChatMLTemplate))
94 }
95
96 pub fn with_template(config: AgentConfig, template: Arc<dyn ChatTemplate>) -> Self {
98 log::debug!(
99 "Agent created: system_prompt_len={}, max_ctx={}, max_resp={}, template={}",
100 config.system_prompt.len(),
101 config.context_config.max_context_tokens,
102 config.context_config.max_response_tokens,
103 template.name(),
104 );
105 Self {
106 config,
107 messages: Vec::new(),
108 #[cfg(feature = "tools")]
109 tools: Vec::new(),
110 template,
111 abort: Arc::new(AtomicBool::new(false)),
112 msg_counter: 0,
113 }
114 }
115
116 pub fn messages(&self) -> &[Message] {
118 &self.messages
119 }
120
121 pub fn config(&self) -> &AgentConfig {
123 &self.config
124 }
125
126 pub fn template(&self) -> &dyn ChatTemplate {
128 self.template.as_ref()
129 }
130
131 pub fn set_system_prompt(&mut self, prompt: impl Into<String>) {
133 let prompt = prompt.into();
134 log::debug!("Agent system prompt updated: len={}", prompt.len());
135 self.config.system_prompt = prompt;
136 }
137
138 pub fn set_inference_params(&mut self, params: InferenceParams) {
140 log::debug!(
141 "Agent inference params: max_tokens={}, temp={}, ctx={}, threads={}",
142 params.max_tokens,
143 params.temperature,
144 params.context_size,
145 params.n_threads,
146 );
147 self.config.inference_params = params;
148 }
149
150 pub fn set_context_config(&mut self, config: ContextConfig) {
152 log::debug!(
153 "Agent context config: max_ctx={}, max_resp={}",
154 config.max_context_tokens,
155 config.max_response_tokens,
156 );
157 self.config.context_config = config;
158 }
159
160 pub fn set_prune_strategy(&mut self, strategy: PruneStrategy) {
162 log::debug!("Agent prune strategy: {strategy:?}");
163 self.config.context_config.prune_strategy = strategy;
164 }
165
166 pub fn set_pinned(&mut self, message_id: &str, pinned: bool) -> bool {
169 match self.messages.iter_mut().find(|m| m.id == message_id) {
170 Some(msg) => {
171 msg.pinned = pinned;
172 log::debug!("Agent message {message_id} pinned={pinned}");
173 true
174 }
175 None => {
176 log::warn!("set_pinned: message {message_id} not found");
177 false
178 }
179 }
180 }
181
182 pub fn set_template(&mut self, template: Arc<dyn ChatTemplate>) {
184 log::debug!("Agent template updated: {}", template.name());
185 self.template = template;
186 }
187
188 #[cfg(feature = "tools")]
192 pub fn set_tools(&mut self, tools: Vec<Box<dyn Tool>>) {
193 log::debug!("Agent tools set: count={}", tools.len());
194 self.tools = tools;
195 }
196
197 pub fn clear(&mut self) {
199 let count = self.messages.len();
200 self.messages.clear();
201 log::debug!("Agent conversation cleared: {count} messages removed");
202 }
203
204 pub fn replace_messages(&mut self, messages: Vec<Message>) {
209 log::debug!("Agent messages replaced: count={}", messages.len());
210 let max_loaded = messages
214 .iter()
215 .filter_map(|m| m.id.strip_prefix("msg-"))
216 .filter_map(|n| n.parse::<u64>().ok())
217 .max()
218 .unwrap_or(0);
219 self.msg_counter = self.msg_counter.max(max_loaded);
220 self.messages = messages;
221 }
222
223 pub fn abort(&self) {
225 log::debug!("Agent abort requested");
226 self.abort.store(true, Ordering::Relaxed);
227 }
228
229 pub fn abort_flag(&self) -> Arc<AtomicBool> {
231 self.abort.clone()
232 }
233
234 fn next_id(&mut self) -> String {
235 self.msg_counter += 1;
236 format!("msg-{}", self.msg_counter)
237 }
238
239 #[cfg(feature = "tools")]
240 fn tool_schemas(&self) -> Vec<ToolSchema> {
241 self.tools.iter().map(|t| t.schema()).collect()
242 }
243
244 #[cfg(not(feature = "tools"))]
246 fn tool_schemas(&self) -> Vec<ToolSchema> {
247 Vec::new()
248 }
249
250 pub async fn prompt(
266 &mut self,
267 text: impl Into<String>,
268 backend: Arc<dyn LlmBackend>,
269 tx: mpsc::UnboundedSender<AgentEvent>,
270 ) -> CoreResult<()> {
271 let text = text.into().trim().to_string();
272 if text.is_empty() {
273 return Err(CoreError::Agent("Empty message".into()));
274 }
275
276 self.abort.store(false, Ordering::Relaxed);
277
278 let user_msg = Message::user(self.next_id(), &text);
279 self.messages.push(user_msg.clone());
280
281 tx.send(AgentEvent::AgentStart).ok();
282 tx.send(AgentEvent::MessageStart {
283 message: user_msg.clone(),
284 })
285 .ok();
286 tx.send(AgentEvent::MessageEnd { message: user_msg }).ok();
287
288 self.compress_if_needed(&backend, &tx).await;
291
292 let mut new_messages: Vec<Message> = Vec::new();
295 #[cfg(feature = "tools")]
296 let has_tools = !self.tools.is_empty();
297 #[cfg(not(feature = "tools"))]
298 let has_tools = false;
299
300 for iteration in 0..self.config.max_tool_iterations {
301 tx.send(AgentEvent::TurnStart).ok();
302
303 let gen = match self.generate_once(backend.clone(), &tx).await {
304 Ok(gen) => gen,
305 Err(CoreError::Aborted) => {
306 log::info!("Agent::prompt: generation aborted by user");
309 let assistant_msg = Message::assistant(self.next_id(), "");
310 self.messages.push(assistant_msg.clone());
311 new_messages.push(assistant_msg.clone());
312 tx.send(AgentEvent::MessageEnd {
313 message: assistant_msg.clone(),
314 })
315 .ok();
316 tx.send(AgentEvent::TurnEnd {
317 message: assistant_msg,
318 tool_results: vec![],
319 })
320 .ok();
321 tx.send(AgentEvent::AgentEnd {
322 messages: new_messages,
323 })
324 .ok();
325 return Ok(());
326 }
327 Err(e) => {
328 log::error!("Agent::prompt: generation error: {e}");
330 if iteration == 0 {
334 self.messages.pop();
335 }
336 tx.send(AgentEvent::Error {
337 message: e.to_string(),
338 })
339 .ok();
340 tx.send(AgentEvent::AgentEnd { messages: vec![] }).ok();
341 return Ok(());
342 }
343 };
344
345 log::debug!(
346 "Agent::prompt: turn {} → {} tokens, {:.1} t/s, {:.1}ms ttft",
347 iteration,
348 gen.tokens_generated,
349 gen.tokens_per_sec,
350 gen.time_to_first_token_ms,
351 );
352
353 let mut assistant_msg = Message::assistant(self.next_id(), &gen.text);
354 let parsed = if has_tools {
355 parse_tool_calls(&gen.text)
356 } else {
357 Vec::new()
358 };
359 let tool_calls: Vec<ToolCall> = parsed
360 .iter()
361 .enumerate()
362 .map(|(i, p)| ToolCall {
363 id: format!("{}-call-{}", assistant_msg.id, i + 1),
364 name: p.name.clone(),
365 arguments: p.arguments.clone(),
366 })
367 .collect();
368 assistant_msg.tool_calls = tool_calls.clone();
369
370 self.messages.push(assistant_msg.clone());
371 new_messages.push(assistant_msg.clone());
372
373 tx.send(AgentEvent::GenerationStats {
374 tokens_generated: gen.tokens_generated,
375 prompt_tokens: gen.prompt_tokens,
376 tokens_per_sec: gen.tokens_per_sec,
377 time_to_first_token_ms: gen.time_to_first_token_ms,
378 generation_time_ms: gen.generation_time_ms,
379 })
380 .ok();
381 tx.send(AgentEvent::MessageEnd {
382 message: assistant_msg.clone(),
383 })
384 .ok();
385
386 if tool_calls.is_empty() {
388 tx.send(AgentEvent::TurnEnd {
389 message: assistant_msg,
390 tool_results: vec![],
391 })
392 .ok();
393 tx.send(AgentEvent::AgentEnd {
394 messages: new_messages,
395 })
396 .ok();
397 return Ok(());
398 }
399
400 #[cfg(feature = "tools")]
404 {
405 let aborted = self
406 .run_tool_calls(&tool_calls, assistant_msg, &mut new_messages, &tx)
407 .await;
408 if aborted {
409 tx.send(AgentEvent::AgentEnd {
410 messages: new_messages,
411 })
412 .ok();
413 return Ok(());
414 }
415 }
417 }
418
419 log::warn!(
421 "Agent::prompt: stopped after {} tool iterations",
422 self.config.max_tool_iterations
423 );
424 tx.send(AgentEvent::Warning {
425 message: format!(
426 "Stopped after {} tool iterations without a final answer",
427 self.config.max_tool_iterations
428 ),
429 })
430 .ok();
431 tx.send(AgentEvent::AgentEnd {
432 messages: new_messages,
433 })
434 .ok();
435 Ok(())
436 }
437
438 #[cfg(feature = "tools")]
442 async fn run_tool_calls(
443 &mut self,
444 tool_calls: &[ToolCall],
445 assistant_msg: Message,
446 new_messages: &mut Vec<Message>,
447 tx: &mpsc::UnboundedSender<AgentEvent>,
448 ) -> bool {
449 let mut tool_results: Vec<ToolResult> = Vec::new();
450 for call in tool_calls {
451 tx.send(AgentEvent::ToolExecStart {
452 tool_call_id: call.id.clone(),
453 tool_name: call.name.clone(),
454 args: call.arguments.clone(),
455 })
456 .ok();
457
458 let (content, is_error) = match self.execute_tool(call, tx).await {
459 Ok(out) => (out.content, false),
460 Err(e) => {
461 log::warn!("Agent::prompt: tool '{}' failed: {e}", call.name);
462 (e.to_string(), true)
463 }
464 };
465
466 let result = ToolResult {
467 tool_call_id: call.id.clone(),
468 tool_name: call.name.clone(),
469 content: content.clone(),
470 is_error,
471 };
472 tx.send(AgentEvent::ToolExecEnd {
473 tool_call_id: call.id.clone(),
474 tool_name: call.name.clone(),
475 result: result.clone(),
476 })
477 .ok();
478
479 let result_msg =
480 Message::tool_result(self.next_id(), &call.id, &call.name, content, is_error);
481 self.messages.push(result_msg.clone());
482 new_messages.push(result_msg.clone());
483 tx.send(AgentEvent::MessageStart {
484 message: result_msg.clone(),
485 })
486 .ok();
487 tx.send(AgentEvent::MessageEnd {
488 message: result_msg,
489 })
490 .ok();
491
492 tool_results.push(result);
493 }
494
495 tx.send(AgentEvent::TurnEnd {
496 message: assistant_msg,
497 tool_results,
498 })
499 .ok();
500
501 self.abort.load(Ordering::Relaxed)
502 }
503
504 pub fn prompt_stream(
550 &mut self,
551 text: impl Into<String>,
552 backend: Arc<dyn LlmBackend>,
553 ) -> (
554 mpsc::UnboundedReceiver<AgentEvent>,
555 impl std::future::Future<Output = CoreResult<()>> + '_,
556 ) {
557 let (tx, rx) = mpsc::unbounded_channel();
558 let text = text.into();
559 let fut = async move { self.prompt(text, backend, tx).await };
560 (rx, fut)
561 }
562
563 async fn generate_once(
570 &self,
571 backend: Arc<dyn LlmBackend>,
572 tx: &mpsc::UnboundedSender<AgentEvent>,
573 ) -> CoreResult<GenerationResult> {
574 let messages = self.messages.clone();
575 let system_prompt = self.config.system_prompt.clone();
576 let ctx_config = self.config.context_config.clone();
577 let tool_schemas = self.tool_schemas();
578 let params = self.config.inference_params.clone();
579 let abort = self.abort.clone();
580 let max_ctx = self.config.context_config.max_context_tokens;
581 let template = self.template.clone();
582 let token_tx = tx.clone();
583 let budget_tx = tx.clone();
584
585 log::debug!(
586 "Agent::generate_once: spawning blocking (max_tokens={}, temp={}, ctx={}, threads={})",
587 params.max_tokens,
588 params.temperature,
589 params.context_size,
590 params.n_threads,
591 );
592
593 let handle = tokio::task::spawn_blocking(move || {
594 if !backend.is_ready() {
595 return Err(CoreError::Backend("No model loaded".into()));
596 }
597
598 let prepared = prepare_context(
599 template.as_ref(),
600 &system_prompt,
601 &messages,
602 &tool_schemas,
603 &ctx_config,
604 &|text| backend.tokenize_count(text).unwrap_or(0),
605 )?;
606
607 log::debug!(
608 "Context prepared: tokens={}, kept={}, pruned={}",
609 prepared.token_count,
610 prepared.messages_included,
611 prepared.messages_pruned,
612 );
613
614 budget_tx
615 .send(AgentEvent::ContextBudget {
616 used_tokens: prepared.token_count,
617 max_tokens: max_ctx,
618 messages_in_context: prepared.messages_included,
619 messages_pruned: prepared.messages_pruned,
620 })
621 .ok();
622
623 backend.generate(
624 &prepared.prompt,
625 ¶ms,
626 abort,
627 Box::new(move |token, count, tps| {
628 token_tx
629 .send(AgentEvent::MessageDelta {
630 delta: token.to_string(),
631 tokens_generated: count,
632 tokens_per_sec: tps,
633 })
634 .ok();
635 }),
636 )
637 });
638
639 handle.await.map_err(|e| {
640 log::error!("Agent::generate_once: blocking task panicked: {e}");
641 CoreError::Agent(format!("Inference task failed: {e}"))
642 })?
643 }
644
645 #[cfg(feature = "tools")]
650 async fn execute_tool(
651 &self,
652 call: &ToolCall,
653 tx: &mpsc::UnboundedSender<AgentEvent>,
654 ) -> CoreResult<ToolOutput> {
655 let Some(tool) = self.tools.iter().find(|t| t.name() == call.name) else {
656 return Err(CoreError::Tool(format!("unknown tool: {}", call.name)));
657 };
658
659 let update_tx = tx.clone();
660 let tool_call_id = call.id.clone();
661 let tool_name = call.name.clone();
662 let on_update: ToolUpdateCallback = Box::new(move |partial: &str| {
663 update_tx
664 .send(AgentEvent::ToolExecUpdate {
665 tool_call_id: tool_call_id.clone(),
666 tool_name: tool_name.clone(),
667 partial: partial.to_string(),
668 })
669 .ok();
670 });
671
672 tool.execute(&call.id, call.arguments.clone(), Some(on_update))
673 .await
674 }
675
676 async fn compress_if_needed(
684 &mut self,
685 backend: &Arc<dyn LlmBackend>,
686 tx: &mpsc::UnboundedSender<AgentEvent>,
687 ) {
688 if self.config.context_config.prune_strategy != PruneStrategy::Summarize {
689 return;
690 }
691
692 let messages = self.messages.clone();
693 let system_prompt = self.config.system_prompt.clone();
694 let tools = self.tool_schemas();
695 let ctx_config = self.config.context_config.clone();
696 let template = self.template.clone();
697 let abort = self.abort.clone();
698 let params = self.config.inference_params.clone();
699 let backend = backend.clone();
700
701 let outcome = tokio::task::spawn_blocking(move || -> Option<(Vec<usize>, String)> {
703 if !backend.is_ready() {
704 return None;
705 }
706 let counter = |t: &str| backend.tokenize_count(t).unwrap_or(0);
707 let plan = plan_prune(
708 template.as_ref(),
709 &system_prompt,
710 &messages,
711 &tools,
712 &ctx_config,
713 &counter,
714 )
715 .ok()?;
716 if plan.dropped.is_empty() {
717 return None; }
719
720 let mut remove: Vec<usize> = plan.dropped.iter().flat_map(|r| r.clone()).collect();
723 let prior_summary = messages
724 .iter()
725 .position(|m| m.pinned && m.content.starts_with(SUMMARY_MARKER));
726 let prior_body = prior_summary.map(|i| {
727 remove.push(i);
728 messages[i]
729 .content
730 .strip_prefix(SUMMARY_MARKER)
731 .unwrap_or(&messages[i].content)
732 .trim()
733 .to_string()
734 });
735 remove.sort_unstable();
736 remove.dedup();
737
738 let transcript = render_transcript(&messages, &remove);
739 let mut body = String::new();
740 if let Some(prev) = prior_body.filter(|s| !s.is_empty()) {
741 body.push_str("Earlier summary:\n");
742 body.push_str(&prev);
743 body.push_str("\n\n");
744 }
745 body.push_str("Conversation excerpt:\n");
746 body.push_str(&transcript);
747
748 let instruction = "You compress conversation history. Summarize the \
749 material below into a concise note that preserves key facts, \
750 decisions, names, and unresolved questions. Reply with only the \
751 summary.";
752 let req = Message::user("summary-req", format!("{instruction}\n\n{body}"));
753 let prompt = template.format(
754 "You summarize conversations faithfully and concisely.",
755 std::slice::from_ref(&req),
756 &[],
757 );
758
759 let sum_params = InferenceParams {
760 max_tokens: SUMMARY_MAX_TOKENS,
761 ..params
762 };
763 let gen = backend
764 .generate(&prompt, &sum_params, abort, Box::new(|_, _, _| {}))
765 .ok()?;
766 let summary = gen.text.trim().to_string();
767 if summary.is_empty() {
768 return None;
769 }
770 Some((remove, summary))
771 })
772 .await;
773
774 let Some((remove, summary)) = outcome.ok().flatten() else {
775 return;
776 };
777
778 self.fold_into_summary(&remove, summary);
779 tx.send(AgentEvent::Warning {
780 message: format!(
781 "Summarized {} earlier message(s) to fit the context window",
782 remove.len()
783 ),
784 })
785 .ok();
786 }
787
788 fn fold_into_summary(&mut self, remove: &[usize], summary: String) {
791 if remove.is_empty() {
792 return;
793 }
794 let insert_at = *remove.iter().min().unwrap();
795 let mut sorted = remove.to_vec();
796 sorted.sort_unstable();
797 for &i in sorted.iter().rev() {
798 if i < self.messages.len() {
799 self.messages.remove(i);
800 }
801 }
802 let summary_msg =
803 Message::user(self.next_id(), format!("{SUMMARY_MARKER}\n{summary}")).pinned();
804 let at = insert_at.min(self.messages.len());
805 self.messages.insert(at, summary_msg);
806 log::info!(
807 "Folded {} messages into a pinned summary at index {at}",
808 remove.len()
809 );
810 }
811}
812
813fn render_transcript(messages: &[Message], indices: &[usize]) -> String {
815 indices
816 .iter()
817 .filter_map(|&i| messages.get(i))
818 .map(|m| {
819 let role = match m.role {
820 Role::User => "User",
821 Role::Assistant | Role::ToolCall => "Assistant",
822 Role::ToolResult => "Tool",
823 Role::System => "System",
824 };
825 format!("{role}: {}", m.content)
826 })
827 .collect::<Vec<_>>()
828 .join("\n")
829}