1use crate::agent::{Agent, EndStrategy};
6use crate::context::{generate_run_id, RunContext, RunUsage, UsageLimits};
7use crate::errors::{AgentRunError, OutputParseError, OutputValidationError};
8use chrono::Utc;
9use serde_json::Value as JsonValue;
10use serdes_ai_core::messages::{RetryPromptPart, ToolReturnPart, UserContent};
11use serdes_ai_core::{
12 FinishReason, ModelRequest, ModelRequestPart, ModelResponse, ModelResponsePart, ModelSettings,
13};
14use serdes_ai_models::ModelRequestParameters;
15use serdes_ai_tools::{ToolError, ToolReturn};
16use std::sync::Arc;
17use tokio_util::sync::CancellationToken;
18
/// How conversation history is shrunk once it crosses the compression
/// threshold.
#[derive(Debug, Clone, Default)]
pub enum CompressionStrategy {
    /// Drop older messages until the target size is reached.
    #[default]
    Truncate,
    /// Replace older messages with a summary.
    Summarize,
}

/// Configuration controlling when and how message history is compressed.
#[derive(Debug, Clone)]
pub struct ContextCompression {
    /// Strategy used to shrink the history.
    pub strategy: CompressionStrategy,
    /// Fraction of the context budget that triggers compression.
    pub threshold: f64,
    /// Token count the history is reduced to when compression runs.
    pub target_tokens: usize,
}

impl Default for ContextCompression {
    fn default() -> Self {
        // Reuse the enum's own `#[default]` instead of naming the
        // variant a second time.
        Self {
            strategy: CompressionStrategy::default(),
            threshold: 0.75,
            target_tokens: 30_000,
        }
    }
}
49
/// Per-run configuration passed when starting an agent run.
///
/// Every field is optional; unset fields fall back to the agent's own
/// configuration (e.g. `model_settings`) or to "absent"/"no limit".
#[derive(Debug, Clone, Default)]
pub struct RunOptions {
    /// Overrides the agent's default model settings for this run.
    pub model_settings: Option<ModelSettings>,
    /// Prior conversation used to seed the run's message history.
    pub message_history: Option<Vec<ModelRequest>>,
    /// Extra usage limits applied to this run only (checked alongside the
    /// agent-level limits on every step).
    pub usage_limits: Option<crate::context::UsageLimits>,
    /// Arbitrary JSON metadata, exposed through the run context.
    pub metadata: Option<JsonValue>,
    /// Context-compression settings. NOTE(review): not read anywhere in
    /// this file — confirm it is consumed elsewhere.
    pub compression: Option<ContextCompression>,
}
64
65impl RunOptions {
66 pub fn new() -> Self {
68 Self::default()
69 }
70
71 pub fn model_settings(mut self, settings: ModelSettings) -> Self {
73 self.model_settings = Some(settings);
74 self
75 }
76
77 pub fn message_history(mut self, history: Vec<ModelRequest>) -> Self {
79 self.message_history = Some(history);
80 self
81 }
82
83 pub fn metadata(mut self, metadata: JsonValue) -> Self {
85 self.metadata = Some(metadata);
86 self
87 }
88
89 pub fn with_compression(mut self, config: ContextCompression) -> Self {
91 self.compression = Some(config);
92 self
93 }
94}
95
/// The outcome of a completed agent run.
#[derive(Debug, Clone)]
pub struct AgentRunResult<Output> {
    /// The validated final output.
    pub output: Output,
    /// Full message history accumulated during the run.
    pub messages: Vec<ModelRequest>,
    /// Every raw model response received, in order.
    pub responses: Vec<ModelResponse>,
    /// Aggregated usage (model requests and tool calls) for the run.
    pub usage: RunUsage,
    /// Unique identifier of the run that produced this result.
    pub run_id: String,
    /// Finish reason of the final model response (`Stop` when none was
    /// reported).
    pub finish_reason: FinishReason,
    /// Metadata carried through from `RunOptions::metadata`.
    pub metadata: Option<JsonValue>,
}
114
115impl<Output> AgentRunResult<Output> {
116 pub fn output(&self) -> &Output {
118 &self.output
119 }
120
121 pub fn into_output(self) -> Output {
123 self.output
124 }
125}
126
/// An in-progress agent run, drivable step by step (`step`) or to
/// completion (`run_to_completion`).
pub struct AgentRun<'a, Deps, Output> {
    /// The agent whose model, tools, and validators drive this run.
    agent: &'a Agent<Deps, Output>,
    // Kept alive for the run's duration; `ctx` holds its own Arc clone.
    #[allow(dead_code)]
    deps: Arc<Deps>,
    /// Mutable per-run state (messages, usage, counters, final output).
    state: AgentRunState<Output>,
    /// Run context handed to tools, validators, and history processors.
    ctx: RunContext<Deps>,
    /// Extra usage limits scoped to this run (from `RunOptions`).
    run_usage_limits: Option<UsageLimits>,
    /// Optional cooperative-cancellation token.
    cancel_token: Option<CancellationToken>,
}
144
/// Internal mutable state for a single run.
struct AgentRunState<Output> {
    /// Conversation history: requests, tool returns, retry prompts.
    messages: Vec<ModelRequest>,
    /// All model responses received so far.
    responses: Vec<ModelResponse>,
    /// Accumulated usage (model requests and tool calls).
    usage: RunUsage,
    /// Unique id generated at construction time.
    run_id: String,
    /// Number of `step` invocations so far.
    step: u32,
    /// How many times output validation has failed and been retried.
    output_retries: u32,
    /// The validated output, once one has been accepted.
    final_output: Option<Output>,
    /// Set when the run is complete.
    finished: bool,
    /// Finish reason from the most recent response that reported one.
    finish_reason: Option<FinishReason>,
}
156
/// Outcome of a single `step` call.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum StepResult {
    /// The step completed without tools, output, or termination; call
    /// `step` again.
    Continue,
    /// This many tool calls were executed and their results appended to
    /// the message history.
    ToolsExecuted(usize),
    /// A validated output was produced and the run ended early
    /// (`EndStrategy::Early`).
    OutputReady,
    /// Output validation failed; a retry prompt was appended.
    RetryingOutput,
    /// The run is finished; further `step` calls return this immediately.
    Finished,
}
171
172impl<'a, Deps, Output> AgentRun<'a, Deps, Output>
173where
174 Deps: Send + Sync + 'static,
175 Output: Send + Sync + 'static,
176{
    /// Create a new run for `agent` with the given user prompt and
    /// dependencies.
    ///
    /// Resolves model settings (per-run override falls back to the agent's
    /// defaults), builds the run context, seeds the message history from
    /// `options.message_history`, then appends the system prompt (if
    /// non-empty) and the user prompt as fresh requests.
    pub async fn new(
        agent: &'a Agent<Deps, Output>,
        prompt: UserContent,
        deps: Deps,
        options: RunOptions,
    ) -> Result<Self, AgentRunError> {
        let run_id = generate_run_id();
        // Dependencies are shared between the run and its context.
        let deps = Arc::new(deps);

        // Per-run settings take precedence over the agent's defaults.
        let model_settings = options
            .model_settings
            .unwrap_or_else(|| agent.model_settings.clone());

        let ctx = RunContext {
            deps: deps.clone(),
            run_id: run_id.clone(),
            start_time: Utc::now(),
            model_name: agent.model().name().to_string(),
            model_settings: model_settings.clone(),
            // Tool fields stay unset until a tool-scoped context is derived.
            tool_name: None,
            tool_call_id: None,
            retry_count: 0,
            metadata: options.metadata.clone(),
        };

        // Start from any caller-provided history.
        let mut messages = options.message_history.unwrap_or_default();

        // The system prompt is built against the fresh context (hence the
        // await); an empty prompt adds no request.
        let system_prompt = agent.build_system_prompt(&ctx).await;
        if !system_prompt.is_empty() {
            let mut req = ModelRequest::new();
            req.add_system_prompt(system_prompt);
            messages.push(req);
        }

        let mut user_req = ModelRequest::new();
        user_req.add_user_prompt(prompt);
        messages.push(user_req);

        Ok(Self {
            agent,
            deps,
            state: AgentRunState {
                messages,
                responses: Vec::new(),
                usage: RunUsage::new(),
                run_id,
                step: 0,
                output_retries: 0,
                final_output: None,
                finished: false,
                finish_reason: None,
            },
            ctx,
            run_usage_limits: options.usage_limits,
            cancel_token: None,
        })
    }
238
239 pub async fn new_with_cancel(
264 agent: &'a Agent<Deps, Output>,
265 prompt: UserContent,
266 deps: Deps,
267 options: RunOptions,
268 cancel_token: CancellationToken,
269 ) -> Result<Self, AgentRunError> {
270 let run_id = generate_run_id();
271 let deps = Arc::new(deps);
272
273 let model_settings = options
274 .model_settings
275 .unwrap_or_else(|| agent.model_settings.clone());
276
277 let ctx = RunContext {
278 deps: deps.clone(),
279 run_id: run_id.clone(),
280 start_time: Utc::now(),
281 model_name: agent.model().name().to_string(),
282 model_settings: model_settings.clone(),
283 tool_name: None,
284 tool_call_id: None,
285 retry_count: 0,
286 metadata: options.metadata.clone(),
287 };
288
289 let mut messages = options.message_history.unwrap_or_default();
291
292 let system_prompt = agent.build_system_prompt(&ctx).await;
294 if !system_prompt.is_empty() {
295 let mut req = ModelRequest::new();
296 req.add_system_prompt(system_prompt);
297 messages.push(req);
298 }
299
300 let mut user_req = ModelRequest::new();
302 user_req.add_user_prompt(prompt);
303 messages.push(user_req);
304
305 Ok(Self {
306 agent,
307 deps,
308 state: AgentRunState {
309 messages,
310 responses: Vec::new(),
311 usage: RunUsage::new(),
312 run_id,
313 step: 0,
314 output_retries: 0,
315 final_output: None,
316 finished: false,
317 finish_reason: None,
318 },
319 ctx,
320 run_usage_limits: options.usage_limits,
321 cancel_token: Some(cancel_token),
322 })
323 }
324
325 pub async fn run_to_completion(mut self) -> Result<AgentRunResult<Output>, AgentRunError> {
327 while !self.state.finished {
328 self.step().await?;
329 }
330 self.finalize()
331 }
332
    /// Execute one iteration of the agent loop: check cancellation and
    /// limits, send the (processed) history to the model, record usage,
    /// and handle the response (tool calls, output parsing, termination).
    ///
    /// # Errors
    /// Fails on cancellation, exceeded usage/time limits, model errors, or
    /// output validation exhausting its retries.
    pub async fn step(&mut self) -> Result<StepResult, AgentRunError> {
        // A finished run is idempotent: further steps are no-ops.
        if self.state.finished {
            return Ok(StepResult::Finished);
        }

        // Cooperative cancellation, checked before any work is done.
        if let Some(ref token) = self.cancel_token {
            if token.is_cancelled() {
                return Err(AgentRunError::Cancelled);
            }
        }

        self.state.step += 1;

        // Agent-level limits…
        if let Some(limits) = &self.agent.usage_limits {
            limits.check(&self.state.usage)?;
            limits.check_time(self.ctx.elapsed_seconds() as u64)?;
        }

        // …and per-run limits are enforced independently.
        if let Some(limits) = &self.run_usage_limits {
            limits.check(&self.state.usage)?;
            limits.check_time(self.ctx.elapsed_seconds() as u64)?;
        }

        let tool_defs = self.agent.tool_definitions();

        let params = ModelRequestParameters::new()
            .with_tools_arc(tool_defs)
            .with_allow_text(true);

        // History processors may rewrite/compress the transcript sent to
        // the model; the stored history itself is not modified here.
        let messages = self.process_history().await;

        let response = self
            .agent
            .model()
            .request(&messages, &self.ctx.model_settings, &params)
            .await?;

        if let Some(usage) = &response.usage {
            self.state.usage.add_request(usage.clone());
        }

        // Keep the most recent non-empty finish reason.
        if response.finish_reason.is_some() {
            self.state.finish_reason = response.finish_reason;
        }
        self.state.responses.push(response.clone());

        self.process_response(response).await
    }
394
395 async fn process_history(&self) -> Vec<ModelRequest> {
396 let mut messages = self.state.messages.clone();
397
398 for processor in &self.agent.history_processors {
400 messages = processor.process(&self.ctx, messages).await;
401 }
402
403 messages
404 }
405
    /// Interpret a model response: collect output candidates and tool
    /// calls, execute tools, validate output, and decide whether the run
    /// continues or finishes.
    async fn process_response(
        &mut self,
        response: ModelResponse,
    ) -> Result<StepResult, AgentRunError> {
        let mut tool_calls = Vec::new();
        let mut found_output = None;

        for part in &response.parts {
            match part {
                ModelResponsePart::Text(text) => {
                    if !text.content.is_empty() {
                        match self.agent.output_schema.parse_text(&text.content) {
                            Ok(output) => found_output = Some(output),
                            Err(OutputParseError::NotFound) => {}
                            // NOTE(review): parse errors other than
                            // NotFound are silently swallowed here —
                            // confirm this is intentional.
                            Err(_) => {}
                        }
                    }
                }
                ModelResponsePart::ToolCall(tc) => {
                    // Output tools yield the final result instead of
                    // being executed.
                    if self.agent.is_output_tool(&tc.tool_name) {
                        let args = tc.args.to_json();
                        if let Ok(output) = self
                            .agent
                            .output_schema
                            .parse_tool_call(&tc.tool_name, &args)
                        {
                            found_output = Some(output);
                            continue;
                        }
                    }

                    // Anything else — including an output tool whose args
                    // failed to parse — is executed as a regular tool call.
                    tool_calls.push(tc.clone());
                }
                // These parts carry no run-control information here.
                ModelResponsePart::Thinking(_) => {
                }
                ModelResponsePart::File(_) => {
                }
                ModelResponsePart::BuiltinToolCall(_) => {
                }
            }
        }

        // Tool calls take priority: execute them and loop again before
        // accepting any output found in the same response.
        if !tool_calls.is_empty() {
            let count = tool_calls.len();
            let returns = self.execute_tool_calls(tool_calls).await;
            self.add_tool_returns(returns)?;
            return Ok(StepResult::ToolsExecuted(count));
        }

        if let Some(output) = found_output {
            match self.validate_output(output).await {
                Ok(validated) => {
                    self.state.final_output = Some(validated);

                    // With `Early`, the first validated output ends the
                    // run; otherwise we wait for the model to stop.
                    if self.agent.end_strategy == EndStrategy::Early {
                        self.state.finished = true;
                        return Ok(StepResult::OutputReady);
                    }
                }
                Err(e) => {
                    self.state.output_retries += 1;
                    if self.state.output_retries > self.agent.max_output_retries {
                        return Err(AgentRunError::OutputValidationFailed(e));
                    }

                    // Ask the model to correct its output and try again.
                    self.add_retry_message(e)?;
                    return Ok(StepResult::RetryingOutput);
                }
            }
        }

        if response.finish_reason == Some(FinishReason::Stop) {
            if self.state.final_output.is_some() {
                self.state.finished = true;
                return Ok(StepResult::Finished);
            }

            // Last chance: the model stopped with plain text — parse and
            // validate it as the final output.
            if let Some(text) = response.parts.iter().find_map(|p| match p {
                ModelResponsePart::Text(t) if !t.content.is_empty() => Some(&t.content),
                _ => None,
            }) {
                if let Ok(output) = self.agent.output_schema.parse_text(text) {
                    match self.validate_output(output).await {
                        Ok(validated) => {
                            self.state.final_output = Some(validated);
                            self.state.finished = true;
                            return Ok(StepResult::Finished);
                        }
                        Err(e) => {
                            // No retry budget here: the model has already
                            // stopped, so validation failure is terminal.
                            return Err(AgentRunError::OutputValidationFailed(e));
                        }
                    }
                }
            }

            // Model stopped without producing any usable output.
            return Err(AgentRunError::UnexpectedStop);
        }

        Ok(StepResult::Continue)
    }
523
524 async fn execute_tool_calls(
525 &mut self,
526 calls: Vec<serdes_ai_core::messages::ToolCallPart>,
527 ) -> Vec<(String, Option<String>, Result<ToolReturn, ToolError>)> {
528 if self.agent.parallel_tool_calls {
529 self.execute_tools_parallel(calls).await
530 } else {
531 self.execute_tools_sequential(calls).await
532 }
533 }
534
    /// Execute tool calls one at a time, in the order the model issued
    /// them.
    ///
    /// Each call is recorded in usage, may be retried up to the tool's
    /// `max_retries` when the error is retryable, and yields a
    /// `(tool_name, tool_call_id, result)` tuple. Calls reached after
    /// cancellation resolve to `ToolError::Cancelled` without executing.
    async fn execute_tools_sequential(
        &mut self,
        calls: Vec<serdes_ai_core::messages::ToolCallPart>,
    ) -> Vec<(String, Option<String>, Result<ToolReturn, ToolError>)> {
        let mut returns = Vec::new();

        for tc in calls {
            // Check cancellation before each individual call so a long
            // batch can be abandoned mid-way.
            if let Some(ref token) = self.cancel_token {
                if token.is_cancelled() {
                    returns.push((
                        tc.tool_name.clone(),
                        tc.tool_call_id.clone(),
                        Err(ToolError::Cancelled),
                    ));
                    continue;
                }
            }

            // Recorded even when the tool turns out not to exist.
            self.state.usage.record_tool_call();

            let tool = match self.agent.find_tool(&tc.tool_name) {
                Some(t) => t,
                None => {
                    returns.push((
                        tc.tool_name.clone(),
                        tc.tool_call_id.clone(),
                        Err(ToolError::NotFound(tc.tool_name.clone())),
                    ));
                    continue;
                }
            };

            // Derive a context carrying this tool's name and call id.
            let tool_ctx = self.ctx.for_tool(&tc.tool_name, tc.tool_call_id.clone());

            let args = tc.args.to_json();
            let mut retries = 0;
            // Retry immediately (no backoff) while the error is retryable
            // and the per-tool retry budget remains.
            let result = loop {
                match tool.executor.execute(args.clone(), &tool_ctx).await {
                    Ok(r) => break Ok(r),
                    Err(e) if e.is_retryable() && retries < tool.max_retries => {
                        retries += 1;
                        continue;
                    }
                    Err(e) => break Err(e),
                }
            };

            returns.push((tc.tool_name.clone(), tc.tool_call_id.clone(), result));
        }

        returns
    }
593
594 async fn execute_tools_parallel(
596 &mut self,
597 calls: Vec<serdes_ai_core::messages::ToolCallPart>,
598 ) -> Vec<(String, Option<String>, Result<ToolReturn, ToolError>)> {
599 use futures::future::join_all;
600
601 for _ in &calls {
603 self.state.usage.record_tool_call();
604 }
605
606 let futures: Vec<_> = calls
608 .into_iter()
609 .map(|tc| {
610 let tool_name = tc.tool_name.clone();
611 let tool_call_id = tc.tool_call_id.clone();
612 let args = tc.args.to_json();
613
614 let tool = self.agent.find_tool(&tc.tool_name).cloned();
616 let tool_ctx = self.ctx.for_tool(&tc.tool_name, tc.tool_call_id.clone());
617
618 async move {
619 let tool = match tool {
620 Some(t) => t,
621 None => {
622 return (
623 tool_name.clone(),
624 tool_call_id,
625 Err(ToolError::NotFound(tool_name)),
626 );
627 }
628 };
629
630 let max_retries = tool.max_retries;
632 let executor = tool.executor;
633 let mut retries = 0;
634
635 let result = loop {
636 match executor.execute(args.clone(), &tool_ctx).await {
637 Ok(r) => break Ok(r),
638 Err(e) if e.is_retryable() && retries < max_retries => {
639 retries += 1;
640 continue;
641 }
642 Err(e) => break Err(e),
643 }
644 };
645
646 (tool_name, tool_call_id, result)
647 }
648 })
649 .collect();
650
651 if let Some(max_concurrent) = self.agent.max_concurrent_tools {
653 self.execute_with_semaphore(futures, max_concurrent).await
654 } else {
655 join_all(futures).await
656 }
657 }
658
    /// Await all `futures` while allowing at most `max_concurrent` of them
    /// to run at once; results come back in the input order (a property of
    /// `join_all`).
    async fn execute_with_semaphore<F, T>(&self, futures: Vec<F>, max_concurrent: usize) -> Vec<T>
    where
        F: std::future::Future<Output = T> + Send,
        T: Send,
    {
        use futures::future::join_all;
        use std::sync::Arc;
        use tokio::sync::Semaphore;

        let semaphore = Arc::new(Semaphore::new(max_concurrent));

        // Each future first acquires a permit; the RAII guard holds it
        // until that future completes, bounding concurrency.
        let wrapped_futures: Vec<_> = futures
            .into_iter()
            .map(|fut| {
                let sem = Arc::clone(&semaphore);
                async move {
                    // `acquire` only fails if the semaphore is closed,
                    // which never happens here — hence the expect.
                    let _permit = sem.acquire().await.expect("Semaphore closed unexpectedly");
                    fut.await
                }
            })
            .collect();

        join_all(wrapped_futures).await
    }
690
    /// Append tool results to the message history.
    ///
    /// First echoes the model response that requested the tools (so the
    /// transcript interleaves assistant turns and tool results correctly),
    /// then appends a single request containing a `ToolReturn` part per
    /// success and a `RetryPrompt` part per failure.
    fn add_tool_returns(
        &mut self,
        returns: Vec<(String, Option<String>, Result<ToolReturn, ToolError>)>,
    ) -> Result<(), AgentRunError> {
        // Record the assistant turn that triggered these tool calls.
        if let Some(last_response) = self.state.responses.last() {
            let mut response_req = ModelRequest::new();
            response_req
                .parts
                .push(ModelRequestPart::ModelResponse(Box::new(
                    last_response.clone(),
                )));
            self.state.messages.push(response_req);
        }

        let mut req = ModelRequest::new();

        for (tool_name, tool_call_id, result) in returns {
            match result {
                Ok(ret) => {
                    let mut part = ToolReturnPart::new(&tool_name, ret.content);
                    if let Some(id) = tool_call_id {
                        part = part.with_tool_call_id(id);
                    }
                    req.parts.push(ModelRequestPart::ToolReturn(part));
                }
                Err(e) => {
                    // Failures are surfaced to the model as retry prompts
                    // rather than aborting the run.
                    let mut part = RetryPromptPart::new(format!("Tool error: {}", e));
                    part = part.with_tool_name(&tool_name);
                    if let Some(id) = tool_call_id {
                        part = part.with_tool_call_id(id);
                    }
                    req.parts.push(ModelRequestPart::RetryPrompt(part));
                }
            }
        }

        if !req.parts.is_empty() {
            self.state.messages.push(req);
        }

        Ok(())
    }
736
737 fn add_retry_message(&mut self, error: OutputValidationError) -> Result<(), AgentRunError> {
738 let mut req = ModelRequest::new();
739 let part = RetryPromptPart::new(error.retry_message());
740 req.parts.push(ModelRequestPart::RetryPrompt(part));
741 self.state.messages.push(req);
742 Ok(())
743 }
744
745 async fn validate_output(&self, output: Output) -> Result<Output, OutputValidationError> {
746 let mut output = output;
747 for validator in &self.agent.output_validators {
748 output = validator.validate(output, &self.ctx).await?;
749 }
750 Ok(output)
751 }
752
753 fn finalize(self) -> Result<AgentRunResult<Output>, AgentRunError> {
754 let output = self.state.final_output.ok_or(AgentRunError::NoOutput)?;
755
756 Ok(AgentRunResult {
757 output,
758 messages: self.state.messages,
759 responses: self.state.responses,
760 usage: self.state.usage,
761 run_id: self.state.run_id,
762 finish_reason: self.state.finish_reason.unwrap_or(FinishReason::Stop),
763 metadata: self.ctx.metadata.clone(),
764 })
765 }
766
767 pub fn messages(&self) -> &[ModelRequest] {
769 &self.state.messages
770 }
771
772 pub fn usage(&self) -> &RunUsage {
774 &self.state.usage
775 }
776
777 pub fn run_id(&self) -> &str {
779 &self.state.run_id
780 }
781
782 pub fn is_finished(&self) -> bool {
784 self.state.finished
785 }
786
787 pub fn step_number(&self) -> u32 {
789 self.state.step
790 }
791
792 pub fn cancel(&self) {
801 if let Some(ref token) = self.cancel_token {
802 token.cancel();
803 }
804 }
805
806 pub fn is_cancelled(&self) -> bool {
811 self.cancel_token
812 .as_ref()
813 .map(|t| t.is_cancelled())
814 .unwrap_or(false)
815 }
816
817 pub fn cancellation_token(&self) -> Option<&CancellationToken> {
822 self.cancel_token.as_ref()
823 }
824}
825
#[cfg(test)]
mod tests {
    use super::*;

    // RunOptions::default() leaves every field unset.
    #[test]
    fn test_run_options_default() {
        let options = RunOptions::default();
        assert!(options.model_settings.is_none());
        assert!(options.message_history.is_none());
    }

    // Builder methods set their respective fields.
    #[test]
    fn test_run_options_builder() {
        let options = RunOptions::new()
            .model_settings(ModelSettings::new().temperature(0.5))
            .metadata(serde_json::json!({"key": "value"}));

        assert!(options.model_settings.is_some());
        assert!(options.metadata.is_some());
    }

    // StepResult equality includes the variant payload.
    #[test]
    fn test_step_result_eq() {
        assert_eq!(StepResult::Continue, StepResult::Continue);
        assert_eq!(StepResult::ToolsExecuted(2), StepResult::ToolsExecuted(2));
        assert_ne!(StepResult::ToolsExecuted(1), StepResult::ToolsExecuted(2));
    }
}